Skip to content

Commit

Permalink
Don't assume object hashes are unique in caching logic (#215)
Browse files Browse the repository at this point in the history
* Fix cache collision edge cases

* Fuzzy cache collision check

* Remove old cache tests

* Comments

* Python 3.7

* Add back cache test (for coverage)

* More correct cache key

* relax cache assert equality
  • Loading branch information
brentyi authored Dec 19, 2024
1 parent 735530a commit 6e9a049
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 16 deletions.
24 changes: 19 additions & 5 deletions src/tyro/_unsafe_cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import sys
from typing import Any, Callable, Dict, List, TypeVar

CallableType = TypeVar("CallableType", bound=Callable)
Expand All @@ -23,11 +24,20 @@ def unsafe_cache(maxsize: int) -> Callable[[CallableType], CallableType]:
def inner(f: CallableType) -> CallableType:
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
key = tuple(unsafe_hash(arg) for arg in args) + tuple(
("__kwarg__", k, unsafe_hash(v)) for k, v in kwargs.items()
key = tuple(_make_key(arg) for arg in args) + tuple(
("__kwarg__", k, _make_key(v)) for k, v in kwargs.items()
)

if key in local_cache:
# Fuzzy check for cache collisions if called from a pytest test.
if "pytest" in sys.modules:
import random

if random.random() < 0.5:
a = f(*args, **kwargs)
b = local_cache[key]
assert a == b or str(a) == str(b)

return local_cache[key]

out = f(*args, **kwargs)
Expand All @@ -41,8 +51,12 @@ def wrapped_f(*args, **kwargs):
return inner


def unsafe_hash(obj: Any) -> Any:
def _make_key(obj: Any) -> Any:
"""Some context: https://github.com/brentyi/tyro/issues/214"""
try:
return hash(obj)
# If the object is hashable, we can use it as a key directly.
hash(obj)
return obj
except TypeError:
return id(obj)
# If the object is not hashable, we'll use assume the type/id are unique...
return type(obj), id(obj)
47 changes: 46 additions & 1 deletion tests/test_dcargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
)

import pytest
from typing_extensions import Annotated, Final, Literal, TypeAlias
from typing_extensions import (
Annotated,
Final,
Literal,
Protocol,
TypeAlias,
runtime_checkable,
)

import tyro

Expand Down Expand Up @@ -953,3 +960,41 @@ class NumericTower:
assert tyro.cli(NumericTower, args="--e 3.2".split(" ")) == NumericTower(e=3.2)
with pytest.raises(SystemExit):
tyro.cli(NumericTower, args="--d False".split(" "))


def test_runtime_checkable_edge_case() -> None:
"""From Kevin Black: https://github.com/brentyi/tyro/issues/214"""

@runtime_checkable
class DummyProtocol(Protocol):
pass

@dataclasses.dataclass(frozen=True)
class SubConfigA:
pass

@dataclasses.dataclass(frozen=True)
class SubConfigB:
pass

@dataclasses.dataclass
class Config:
subconfig: DummyProtocol

CONFIGS = {
"a": Config(subconfig=SubConfigA()),
"b": Config(subconfig=SubConfigB()),
}

assert (
tyro.extras.overridable_config_cli(
{k: (k, v) for k, v in CONFIGS.items()}, args=["a"]
).subconfig
== SubConfigA()
)
assert (
tyro.extras.overridable_config_cli(
{k: (k, v) for k, v in CONFIGS.items()}, args=["b"]
).subconfig
== SubConfigB()
)
40 changes: 40 additions & 0 deletions tests/test_py311_generated/test_dcargs_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
List,
Literal,
Optional,
Protocol,
Text,
Tuple,
TypeAlias,
TypeVar,
runtime_checkable,
)

import pytest
Expand Down Expand Up @@ -955,3 +957,41 @@ class NumericTower:
assert tyro.cli(NumericTower, args="--e 3.2".split(" ")) == NumericTower(e=3.2)
with pytest.raises(SystemExit):
tyro.cli(NumericTower, args="--d False".split(" "))


def test_runtime_checkable_edge_case() -> None:
"""From Kevin Black: https://github.com/brentyi/tyro/issues/214"""

@runtime_checkable
class DummyProtocol(Protocol):
pass

@dataclasses.dataclass(frozen=True)
class SubConfigA:
pass

@dataclasses.dataclass(frozen=True)
class SubConfigB:
pass

@dataclasses.dataclass
class Config:
subconfig: DummyProtocol

CONFIGS = {
"a": Config(subconfig=SubConfigA()),
"b": Config(subconfig=SubConfigB()),
}

assert (
tyro.extras.overridable_config_cli(
{k: (k, v) for k, v in CONFIGS.items()}, args=["a"]
).subconfig
== SubConfigA()
)
assert (
tyro.extras.overridable_config_cli(
{k: (k, v) for k, v in CONFIGS.items()}, args=["b"]
).subconfig
== SubConfigB()
)
11 changes: 6 additions & 5 deletions tests/test_py311_generated/test_unsafe_cache_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,24 @@ def f(dummy: int):
nonlocal x
x += 1

# >= is because of fuzz testing inside of unsafe_cache
f(0)
f(0)
f(0)
assert x == 1
assert x >= 1
f(1)
f(1)
f(1)
assert x == 2
assert x >= 2
f(0)
f(0)
f(0)
assert x == 2
assert x >= 2
f(2)
f(2)
f(2)
assert x == 3
assert x >= 3
f(0)
f(0)
f(0)
assert x == 4
assert x >= 4
11 changes: 6 additions & 5 deletions tests/test_unsafe_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,24 @@ def f(dummy: int):
nonlocal x
x += 1

# >= is because of fuzz testing inside of unsafe_cache
f(0)
f(0)
f(0)
assert x == 1
assert x >= 1
f(1)
f(1)
f(1)
assert x == 2
assert x >= 2
f(0)
f(0)
f(0)
assert x == 2
assert x >= 2
f(2)
f(2)
f(2)
assert x == 3
assert x >= 3
f(0)
f(0)
f(0)
assert x == 4
assert x >= 4

0 comments on commit 6e9a049

Please sign in to comment.