Skip to content

Commit

Permalink
Fix dtypes for qr
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Oct 8, 2024
1 parent c27c28f commit e783c57
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions cubed/array_api/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

# These functions are in both the main and linalg namespaces
from cubed.array_api.data_type_functions import result_type
from cubed.array_api.dtypes import _floating_dtypes
from cubed.array_api.linear_algebra_functions import ( # noqa: F401
matmul,
matrix_transpose,
Expand Down Expand Up @@ -33,6 +34,9 @@ def qr(x, /, *, mode="reduced") -> QRResult:
if mode != "reduced":
raise ValueError("qr only supports mode='reduced'")

if x.dtype not in _floating_dtypes:
raise TypeError("Only floating-point dtypes are allowed in qr")

if x.numblocks[1] > 1:
raise ValueError(
"qr only supports tall-and-skinny (single column chunk) arrays. "
Expand Down Expand Up @@ -80,7 +84,7 @@ def _qr_first_step(A):
nxp.linalg.qr,
A,
shapes=[A.shape, R1_shape],
dtypes=[nxp.float64, nxp.float64],
dtypes=[A.dtype, A.dtype],
chunkss=[A.chunks, R1_chunks],
extra_projected_mem=extra_projected_mem,
)
Expand Down Expand Up @@ -119,7 +123,7 @@ def _qr_second_step(R1):
nxp.linalg.qr,
R1_single,
shapes=[Q2_shape, R2_shape],
dtypes=[nxp.float64, nxp.float64],
dtypes=[R1.dtype, R1.dtype],
chunkss=[Q2_chunks, R2_chunks],
extra_projected_mem=extra_projected_mem,
)
Expand Down Expand Up @@ -153,7 +157,7 @@ def key_function(out_key):
Q1,
Q2,
shapes=[Q1_shape],
dtypes=[nxp.float64],
dtypes=[result_type(Q1, Q2)],
chunkss=[Q1_chunks],
q2_chunks=Q2_chunks,
)
Expand Down

0 comments on commit e783c57

Please sign in to comment.