Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More informative structure mismatch error message in Linear layer #92

Merged
merged 3 commits into from
Nov 21, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions penzai/nn/linear_and_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@
from penzai.core import named_axes
from penzai.core import shapecheck
from penzai.core import struct
from penzai.core import variables
from penzai.nn import grouping
from penzai.nn import layer as layer_base
from penzai.nn import parameters

NamedArray = named_axes.NamedArray
Parameter = variables.Parameter
ParameterValue = variables.ParameterValue


class LinearOperatorWeightInitializer(Protocol):
Expand Down Expand Up @@ -421,12 +424,24 @@ class Linear(layer_base.Layer):
def __call__(self, in_array: NamedArray, **_unused_side_inputs) -> NamedArray:
"""Runs the linear operator."""
in_struct = self._input_structure()
dimvars = shapecheck.check_structure(in_array, in_struct)
if isinstance(
self.weights,
Parameter | ParameterValue,
) and self.weights.label.endswith(".weights"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like pytype is raising an error here. I don't think it knows how to refine the type of self.weights. Would you mind either:

  • refactoring this so that you first assign to a local variable weights = self.weights and then change the rest of the function to refer to weights,
  • or just adding a comment # pytype: disable=attribute-error here to tell pytype you know what you're doing?

Copy link
Collaborator

@danieldjohnson danieldjohnson Nov 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(You should be able to run the typechecker yourself with uv run pytype --jobs auto penzai as long as the dev dependencies are installed.)

Copy link
Contributor Author

@amifalk amifalk Nov 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, for someone reason locally disabling the error doesn't seem to work

    if isinstance(
        self.weights,
        Parameter | ParameterValue,
    ) and (  # pytype: disable=attribute-error
        self.weights.label.endswith(".weights")
    ):
      error_prefix = (  # pytype: disable=attribute-error
          f"({self.weights.label[:-8]}) "
      )
    else:
      error_prefix = ""
``

and neither does assigning self.weights to weights. Globally ignoring the attribute error does seem to work though.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm that's weird. Does it work if you add

# pytype: disable=attribute-error

on a line on its own before the if block, and

# pytype: enable=attribute-error

on a line after?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That did the trick!

error_prefix = f"({self.weights.label[:-8]}) "
else:
error_prefix = ""

dimvars = shapecheck.check_structure(
in_array, in_struct, error_prefix=error_prefix
)

result = contract(self.in_axis_names, in_array, self.weights.value)

out_struct = self._output_structure()
shapecheck.check_structure(result, out_struct, known_vars=dimvars)
shapecheck.check_structure(
result, out_struct, known_vars=dimvars, error_prefix=error_prefix
)
return result

@classmethod
Expand Down
Loading