Skip to content

Commit

Permalink
Add call tracing module (#569)
Browse files Browse the repository at this point in the history
* Add decorator for call tracing

* Add function for applying a decorator to all functions in a module

* Add some comments

* Add function to enable tracing for most scico functions

* Improve trace call detail

* Suppress typing errors

* Rectify oversight

* No apparent reason for method to be dynamic here

* Some improvements

* Add missing copyright statement

* Fix incorrect jit of method

* Add missing type annotations

* Update submodule

* Move trace functions to their own module

* Add some comments

* Extend colour to all args

* Fix handling of static and class methods

* Add trace usage example

* Different colour for function name

* Add display of return values

* Improve placement of register_variable calls

* Trivial edit

* Bug fix

* Add comments and improve some variable names

* Re-write of apply_decorator function in progress

* Improve verbose output

* Add docstrings

* Suppress mypy complaints

* Output format improvement

* Clean up

* Output format improvement

* Add additional colour coding

* Add colorama dependency

* Update change log

* Exclude trace module from coverage

* Add docs

* Minor edit

* Add option for displaying jax array device and sharding information

* Typo fix

* Update change summary

* Improve docs

* Suppress mypy errors

* Address PR review comments
  • Loading branch information
bwohlberg authored Nov 8, 2024
1 parent 6ce6a15 commit 20899df
Show file tree
Hide file tree
Showing 9 changed files with 598 additions and 18 deletions.
3 changes: 2 additions & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
[run]
source = scico
command_line = -m pytest
omit =
omit =
scico/test/*
scico/plot.py
scico/trace.py

[report]
# Regexes for lines to exclude from consideration
Expand Down
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ SCICO Release Notes
Version 0.0.7 (unreleased)
----------------------------

No changes yet.
New module ``scico.trace`` for tracing function/method calls.



Expand Down
1 change: 1 addition & 0 deletions examples/examples_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
-r ../requirements.txt
colorama
colour_demosaicing
svmbir>=0.4.0
astra-toolbox
Expand Down
23 changes: 13 additions & 10 deletions examples/jnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,17 @@ def py_file_to_string(src):
if re.match("^import|^from .* import", line):
import_seen = True
lines.append(line)
# Backtrack through list of lines to find last import statement
n = 1
for line in lines[-2::-1]:
if re.match("^(import|from)", line):
break
else:
n += 1
# Insert notebook plotting config directly after last import statement
lines.insert(-n, "plot.config_notebook_plotting()\n")

if "plot" in "".join(lines):
# Backtrack through list of lines to find last import statement
n = 1
for line in lines[-2::-1]:
if re.match("^(import|from)", line):
break
else:
n += 1
# Insert notebook plotting config directly after last import statement
lines.insert(-n, "plot.config_notebook_plotting()\n")

# Process remainder of source file
for line in srcfile:
Expand All @@ -73,7 +75,8 @@ def py_file_to_string(src):
n += 1
else:
break
lines = lines[0:-n]
if n > 0:
lines = lines[0:-n]

return "".join(lines)

Expand Down
110 changes: 110 additions & 0 deletions examples/scripts/trace_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

r"""
SCICO Call Tracing
==================
This example demonstrates the call tracing functionality provided by the
[trace](../_autosummary/scico.trace.rst) module. It is based on the
[non-negative BPDN example](sparsecode_nn_admm.rst).
"""


import numpy as np

import jax

import scico.numpy as snp
from scico import functional, linop, loss, metric
from scico.optimize.admm import ADMM, MatrixSubproblemSolver
from scico.trace import register_variable, trace_scico_calls
from scico.util import device_info

"""
Initialize tracing. JIT must be disabled for correct tracing.
The call tracing mechanism prints the name, arguments, and return values
of functions/methods as they are called. Module and class names are
printed in light red, function and method names in dark red, arguments
and return values in light blue, and the names of registered variables
in light yellow. When a method defined in a class is called for an object
of a derived class type, the class of that object is printed in light
magenta, in square brackets. Function names and return values are
distinguished by initial ">>" and "<<" characters respectively.
"""
jax.config.update("jax_disable_jit", True)
trace_scico_calls()


"""
Create random dictionary, reference random sparse representation, and
test signal consisting of the synthesis of the reference sparse
representation.
"""
m = 32 # signal size
n = 128 # dictionary size
s = 10 # sparsity level

np.random.seed(1)
D = np.random.randn(m, n).astype(np.float32)
D = D / np.linalg.norm(D, axis=0, keepdims=True) # normalize dictionary

xt = np.zeros(n, dtype=np.float32) # true signal
idx = np.random.randint(low=0, high=n, size=s) # support of xt
xt[idx] = np.random.rand(s)
y = D @ xt + 5e-2 * np.random.randn(m) # synthetic signal

xt = snp.array(xt) # convert to jax array
y = snp.array(y) # convert to jax array


"""
Register a variable so that it can be referenced by name in the call trace.
Any hashable object and numpy arrays may be registered, but JAX arrays
cannot.
"""
register_variable(D, "D")


"""
Set up the forward operator and ADMM solver object.
"""
lmbda = 1e-1
A = linop.MatrixOperator(D)
register_variable(A, "A")
f = loss.SquaredL2Loss(y=y, A=A)
g_list = [lmbda * functional.L1Norm(), functional.NonNegativeIndicator()]
C_list = [linop.Identity((n)), linop.Identity((n))]
rho_list = [1.0, 1.0]
maxiter = 1 # number of ADMM iterations (set to small value to simplify trace output)

register_variable(f, "f")
register_variable(g_list[0], "g_list[0]")
register_variable(g_list[1], "g_list[1]")
register_variable(C_list[0], "C_list[0]")
register_variable(C_list[1], "C_list[1]")

solver = ADMM(
f=f,
g_list=g_list,
C_list=C_list,
rho_list=rho_list,
x0=A.adj(y),
maxiter=maxiter,
subproblem_solver=MatrixSubproblemSolver(),
itstat_options={"display": True, "period": 5},
)

register_variable(solver, "solver")


"""
Run the solver.
"""
print(f"Solving on {device_info()}\n")
x = solver.solve()
mse = metric.mse(xt, x)
18 changes: 13 additions & 5 deletions scico/optimize/_pgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

from functools import partial
from typing import Optional, Union

import jax
Expand Down Expand Up @@ -101,15 +102,22 @@ def __init__(
self.L: float = L0 # reciprocal of step size (estimate of Lipschitz constant of ∇f)
self.fixed_point_residual = snp.inf

def x_step(v: Union[Array, BlockArray], L: float) -> Union[Array, BlockArray]:
return self.g.prox(v - 1.0 / L * self.f.grad(v), 1.0 / L)

self.x_step = jax.jit(x_step)

self.x: Union[Array, BlockArray] = x0 # current estimate of solution

super().__init__(**kwargs)

def x_step(self, v: Union[Array, BlockArray], L: float) -> Union[Array, BlockArray]:
"""Compute update for variable `x`."""
return PGM._x_step(self.f, self.g, v, L)

@staticmethod
@partial(jax.jit, static_argnums=(0, 1))
def _x_step(
f: Functional, g: Functional, v: Union[Array, BlockArray], L: float
) -> Union[Array, BlockArray]:
"""Jit-able static method for computing update for variable `x`."""
return g.prox(v - 1.0 / L * f.grad(v), 1.0 / L)

def _working_vars_finite(self) -> bool:
"""Determine where ``NaN`` of ``Inf`` encountered in solve.
Expand Down
1 change: 1 addition & 0 deletions scico/optimize/pgm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand Down
Loading

0 comments on commit 20899df

Please sign in to comment.