diff --git a/robrank/losses/infonce.py b/robrank/losses/infonce.py index 6c5036e..f5aba34 100644 --- a/robrank/losses/infonce.py +++ b/robrank/losses/infonce.py @@ -15,7 +15,7 @@ t=0.2 R@1 = 88.4 ''' -def infonce(repA: th.Tensor, repB: th.Tensor, *, metric:str='C', t:float = 0.2) -> th.Tensor: +def infonce(repA: th.Tensor, repB: th.Tensor, *, metric:str='C', t:float = 0.0) -> th.Tensor: # make sure shape is correct repA, repB = th.flatten(repA, 1), th.flatten(repB, 1) assert metric == 'C'