diff --git a/src/features.rs b/src/features.rs index 9b554e87..5e127bc2 100644 --- a/src/features.rs +++ b/src/features.rs @@ -5,9 +5,7 @@ use core::{cell::RefMut, fmt::Debug, ops::RangeBounds}; use crate::{ - op_hint::OpHint, - range::{AsRange, CursorRange}, - AnyOp, HasId, Parents, Shape, UniqueId, Unit, WrappedData, ZeroGrad, CPU, + location, op_hint::OpHint, range::{AsRange, CursorRange}, AnyOp, HasId, Parents, Shape, UniqueId, Unit, WrappedData, ZeroGrad, CPU }; #[cfg(feature = "cached")] @@ -622,6 +620,18 @@ pub trait UseGpuOrCpu { gpu_op: impl FnMut(), ) -> GpuOrCpuInfo; + #[track_caller] + #[inline] + fn use_cpu_or_gpu_tracked( + &self, + input_lengths: &[usize], + cpu_op: impl FnMut(), + gpu_op: impl FnMut(), + ) -> GpuOrCpuInfo { + let location = crate::HashLocation::from(core::panic::Location::caller()); + self.use_cpu_or_gpu(location, input_lengths, cpu_op, gpu_op) + } + fn set_fork_enabled(&self, _enabled: bool); #[inline]