diff --git a/example-hello/src/bin/hello_client.rs b/example-hello/src/bin/hello_client.rs index 2999537..c8b851b 100644 --- a/example-hello/src/bin/hello_client.rs +++ b/example-hello/src/bin/hello_client.rs @@ -50,7 +50,8 @@ fn main() -> std::result::Result<(), Box> { let hello: rsbinder::Strong = hub::get_interface(SERVICE_NAME) .unwrap_or_else(|_| panic!("Can't find {SERVICE_NAME}")); - hello.as_binder().link_to_death(Arc::new(MyDeathRecipient{}))?; + let recipient = Arc::new(MyDeathRecipient{}); + hello.as_binder().link_to_death(Arc::downgrade(&(recipient as Arc)))?; // Call echo method of Hello proxy. let echo = hello.echo("Hello World!")?; diff --git a/rsbinder-tools/src/bin/rsb_hub.rs b/rsbinder-tools/src/bin/rsb_hub.rs index c10a7dc..ed77b33 100644 --- a/rsbinder-tools/src/bin/rsb_hub.rs +++ b/rsbinder-tools/src/bin/rsb_hub.rs @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 #![allow(non_snake_case)] -use std::{collections::HashMap, sync::{RwLock, Arc}}; +use std::{collections::HashMap, sync::{mpsc, Arc, Mutex}}; use hub::{IServiceManager, BnServiceManager, DUMP_FLAG_PRIORITY_DEFAULT}; use env_logger::Env; use clap; @@ -12,9 +12,10 @@ struct Service { binder: SIBinder, _allow_isolated: bool, dump_priority: i32, - _has_clients: bool, + has_clients: bool, guarentee_client: bool, _debug_pid: u32, + context: rsbinder::thread_state::CallingContext, } impl Service { @@ -27,103 +28,208 @@ impl Service { } } -struct ServiceManagerInner { - name_to_service: RwLock>, - name_to_registration_callbacks: RwLock>>>, +struct DeathRecipientWrapper(mpsc::Sender); + +impl rsbinder::DeathRecipient for DeathRecipientWrapper { + fn binder_died(&self, who: &rsbinder::WIBinder) { + self.0.send(who.clone()) + .unwrap_or_else(|e| { + log::error!("Failed to send death notification: {:?}", e); + }); + } } -impl ServiceManagerInner { - fn try_get_service(&self, name: &str, _start_if_not_found: bool) -> rsbinder::status::Result> { - self.name_to_service.write().unwrap().get_mut(name).map(|service| { - service.guarentee_client = true; - Ok(Some(service.binder.clone())) - }).unwrap_or_else(|| { - Ok(None) - }) +struct Inner { + death_recipient: Arc, + name_to_service: HashMap, + name_to_registration_callbacks: HashMap>>, + name_to_client_callbacks: HashMap>>, +} + +impl Inner { + fn new(death_sender: mpsc::Sender) -> Self { + Self { + death_recipient: Arc::new(DeathRecipientWrapper(death_sender)), + name_to_service: HashMap::new(), + name_to_registration_callbacks: HashMap::new(), + name_to_client_callbacks: HashMap::new(), + } } - fn add_service(&self, name: &str, service: Service) -> rsbinder::status::Result<()> { - self.name_to_service.write().unwrap().insert(name.to_owned(), service); + fn add_service(&mut self, name: &str, service: Service) -> rsbinder::status::Result<()> { + self.name_to_service.insert(name.to_owned(), service); Ok(()) } - fn on_registration(&self, name: &str) -> rsbinder::status::Result<()> { - if let Some(service) = self.name_to_service.read().unwrap().get(name) { - let callbacks = self.name_to_registration_callbacks.read().unwrap().get(name).cloned(); - if let Some(callbacks) = callbacks { - for callback in callbacks { - callback.onRegistration(name, &service.binder)?; - } - } + fn send_client_callback_notification(&mut self, service_name: &str, has_clients: bool, context: &str) { + let service = if let Some(service) = self.name_to_service.get_mut(service_name) { + service + } else { + log::warn!("send_client_callback_notification could not find service {} when {}", service_name, context); + return; + }; + + if service.has_clients == has_clients { + log::error!("send_client_callback_notification called with the same state {} when {}", + has_clients, context); + std::process::abort() } - Ok(()) - } - fn list_services(&self, dump_priority: i32) -> rsbinder::status::Result> { - let mut services = Vec::new(); + log::info!("Notifying {} they {} (previously: {}) have clients when {}", + service_name, if has_clients { "do" } else { "don't" }, + if service.has_clients { "do" } else { "don't" }, context); - for (name, service) in self.name_to_service.read().unwrap().iter() { - if (service.dump_priority & dump_priority) != 0 { - services.push(name.clone()); + self.name_to_client_callbacks.get(service_name).map(|callbacks| { + for callback in callbacks { + callback.onClients(&service.binder, has_clients) + .unwrap_or_else(|e| { + log::error!("Failed to notify client callback: {:?}", e); + }); } - } + }).unwrap_or_else(|| { + log::warn!("send_client_callback_notification could not find callbacks for service when {}", context); + }); - Ok(services) + service.has_clients = has_clients; } - fn register_for_notifications(&self, name: &str, callback: &rsbinder::Strong) -> rsbinder::status::Result<()> { - let mut callbacks = self.name_to_registration_callbacks.write().unwrap(); - let callbacks = callbacks.entry(name.to_owned()).or_default(); - callbacks.push(callback.clone()); + fn handle_service_client_callback(&mut self, known_clients: usize, service_name: &str, is_called_on_interval: bool) -> Result { + let service = if let Some(service) = self.name_to_service.get(service_name) { + if self.name_to_client_callbacks.get(service_name) + .map_or(true, |callbacks| callbacks.is_empty()) { + return Ok(true); + } + service + } else { + return Ok(true); + }; + + let count = match rsbinder::ProcessState::as_self() + .strong_ref_count_for_node(service.binder.as_proxy().expect("Service must be a proxy")) { + Ok(count) => count, + Err(e) => { + log::error!("Failed to get strong ref count for {}: {:?}", service_name, e); + return Ok(true); + } + }; + let has_kernel_reported_clients = count > known_clients; - if let Some(service) = self.name_to_service.read().unwrap().get(name) { - callback.onRegistration(name, &service.binder)?; - } + // To avoid the borrow checker, we need to get the value of has_clients + let mut has_clients = service.has_clients; - Ok(()) - } + if service.guarentee_client { + if !has_clients && !has_kernel_reported_clients { + self.send_client_callback_notification(service_name, true, + "service is guaranteed to be in use"); + } - fn unregister_for_notifications(&self, name: &str, callback: &rsbinder::Strong) -> rsbinder::status::Result<()> { - let mut callbacks = self.name_to_registration_callbacks.write().unwrap(); - if let Some(callbacks) = callbacks.get_mut(name) { - callbacks.retain(|c| c.as_binder() != callback.as_binder()); - Ok(()) - } else { - log::error!("Trying to unregister callback, but none exists {}", name); - Err(ExceptionCode::IllegalState.into()) + self.name_to_service.get_mut(service_name).map(|service| { + // guarantee is temporary + service.has_clients = true; + has_clients = true; + }); } - } -} -impl Default for ServiceManagerInner { - fn default() -> Self { - Self { - name_to_service: RwLock::new(HashMap::new()), - name_to_registration_callbacks: RwLock::new(HashMap::new()), + if has_kernel_reported_clients && !has_clients { + self.send_client_callback_notification(service_name, true, + "we now have a record of a client"); + self.name_to_service.get(service_name).map(|service| { + has_clients = service.has_clients; + }); } + + if is_called_on_interval { + if !has_kernel_reported_clients && has_clients{ + self.send_client_callback_notification(service_name, false, + "we now have no record of a client"); + self.name_to_service.get(service_name).map(|service| { + has_clients = service.has_clients; + }); + } + } + + Ok(has_clients) } -} -impl rsbinder::DeathRecipient for ServiceManagerInner { - fn binder_died(&self, who: &rsbinder::WIBinder) { - self.name_to_service.write().unwrap().retain(|_, service| { - !(SIBinder::downgrade(&service.binder) == *who) + + fn try_get_binder(&mut self, name:&str, _start_if_not_found: bool) -> rsbinder::status::Result> { + let service = if let Some(service) = self.name_to_service.get_mut(name) { + service + } else { + return Ok(None); + }; + + let out = service.binder.clone(); + service.guarentee_client = true; + self.handle_service_client_callback(2 /* sm + transaction */, name, false)?; + + self.name_to_service.get_mut(name).map(|service| { + service.guarentee_client = true; }); - self.name_to_registration_callbacks.write().unwrap().retain(|_, callbacks| { - callbacks.retain(|callback| { - SIBinder::downgrade(&callback.as_binder()) != *who + Ok(Some(out)) + } + + fn remove_registration_callback(&mut self, name: Option<&str>, who: &rsbinder::WIBinder) -> bool { + let mut found = false; + if let Some(name) = name { + if let Some(callbacks) = self.name_to_registration_callbacks.get_mut(name) { + callbacks.retain(|callback| { + let is_not_equal = SIBinder::downgrade(&callback.as_binder()) != *who; + found |= !is_not_equal; + is_not_equal + }); + if callbacks.is_empty() { + self.name_to_registration_callbacks.remove(name); + } + } + } else { + self.name_to_registration_callbacks.retain(|_, callbacks| { + callbacks.retain(|callback| { + let is_not_equal = SIBinder::downgrade(&callback.as_binder()) != *who; + found |= !is_not_equal; + is_not_equal + }); + !callbacks.is_empty() }); - !callbacks.is_empty() - }); + } + + found } } struct ServiceManager { - inner: Arc, + inner: Arc>, } impl ServiceManager { + fn new() -> Self { + let (death_sender, death_receiver) = mpsc::channel(); + + let this = Self { + inner: Arc::new(Mutex::new(Inner::new(death_sender))), + }; + + this.run_death_receiver(death_receiver); + + this + } + + fn run_death_receiver(&self, death_receiver: mpsc::Receiver) { + let inner_clone = Arc::clone(&self.inner); + std::thread::spawn(move || { + for who in death_receiver { + let mut inner = inner_clone.lock().unwrap(); + + inner.name_to_service.retain(|_, service| { + !(SIBinder::downgrade(&service.binder) == who) + }); + + inner.remove_registration_callback(None, &who); + } + }); + } + fn is_valid_service_name(name: &str) -> bool { if name.is_empty() || name.len() > 127 { return false; @@ -148,19 +254,11 @@ impl ServiceManager { } } -impl Default for ServiceManager { - fn default() -> Self { - Self { - inner: Arc::new(ServiceManagerInner::default()), - } - } -} - impl Interface for ServiceManager {} impl IServiceManager for ServiceManager { - fn getService(&self,_arg_name: &str) -> rsbinder::status::Result> { - self.inner.try_get_service(_arg_name, true) + fn getService(&self, name: &str) -> rsbinder::status::Result> { + self.inner.lock().unwrap().try_get_binder(name, false) } fn addService(&self, name: &str, service: &SIBinder, allowIsolated: bool, dumpPriority: i32) -> rsbinder::status::Result<()> { @@ -168,77 +266,233 @@ impl IServiceManager for ServiceManager { return Err(ExceptionCode::IllegalArgument.into()); } + let mut inner = self.inner.lock().unwrap(); + + // Only if the service is a proxy, link to death. + // Because the native service does not support death notification. if service.as_proxy().is_some() { - service.link_to_death(self.inner.clone())?; + service.link_to_death(Arc::downgrade(&(inner.death_recipient.clone() as Arc)))?; + } + + let mut prev_clients = false; + { + inner.name_to_service.get(name).map(|service| { + prev_clients = service.has_clients; + }); } - self.inner.add_service(name, Service { + + inner.add_service(name, Service { binder: service.clone(), _allow_isolated: allowIsolated, dump_priority: dumpPriority, - _has_clients: false, + has_clients: prev_clients, guarentee_client: false, _debug_pid: 0, + context: rsbinder::thread_state::CallingContext::new(), })?; - self.inner.on_registration(name)?; + if inner.name_to_registration_callbacks.contains_key(name) { + inner.name_to_service.get_mut(name).map(|service| { + service.guarentee_client = true; + }); + + inner.handle_service_client_callback(2, name, false)?; + + inner.name_to_service.get_mut(name).map(|service| { + service.guarentee_client = true; + }); + + let callbacks = inner.name_to_registration_callbacks.get(name) + .expect("name_to_registration_callbacks must have key"); + for callback in callbacks { + callback.onRegistration(name, service)?; + } + } Ok(()) } fn checkService(&self, name: &str) -> rsbinder::status::Result> { - self.inner.try_get_service(name, false) + self.inner.lock().unwrap().try_get_binder(name, false) } - fn listServices(&self, dumpPriority: i32) -> rsbinder::status::Result> { - self.inner.list_services(dumpPriority) + fn listServices(&self, dump_priority: i32) -> rsbinder::status::Result> { + let inner = self.inner.lock().unwrap(); + + let mut services = Vec::new(); + + for (name, service) in inner.name_to_service.iter() { + if (service.dump_priority & dump_priority) != 0 { + services.push(name.clone()); + } + } + + Ok(services) } - fn registerForNotifications(&self, _arg_name: &str, _arg_callback: &rsbinder::Strong) -> rsbinder::status::Result<()> { - if !Self::is_valid_service_name(_arg_name) { + fn registerForNotifications(&self, name: &str, arg_callback: &rsbinder::Strong) -> rsbinder::status::Result<()> { + if !Self::is_valid_service_name(name) { return Err(ExceptionCode::IllegalArgument.into()); } - _arg_callback.as_binder().link_to_death(self.inner.clone())?; + let mut inner = self.inner.lock().unwrap(); + + arg_callback.as_binder().link_to_death(Arc::downgrade(&(inner.death_recipient.clone() as Arc)))?; + + inner.name_to_registration_callbacks + .entry(name.to_string()) + .or_insert_with(Vec::new) + .push(arg_callback.clone()); - self.inner.register_for_notifications(_arg_name, _arg_callback) + inner.name_to_service.get(name).map(|service| { + arg_callback.onRegistration(name, &service.binder) + .unwrap_or_else(|e| { + log::error!("Failed to notify client callback: {:?}", e); + }); + }); + + Ok(()) } - fn unregisterForNotifications(&self, _arg_name: &str, _arg_callback: &rsbinder::Strong) -> rsbinder::status::Result<()> { - self.inner.unregister_for_notifications(_arg_name, _arg_callback) + fn unregisterForNotifications(&self, name: &str, callback: &rsbinder::Strong) -> rsbinder::status::Result<()> { + let mut inner = self.inner.lock().unwrap(); + + if inner.remove_registration_callback(Some(name), + &SIBinder::downgrade(&callback.as_binder())) { + Ok(()) + } else { + Err(ExceptionCode::IllegalState.into()) + } } fn isDeclared(&self,_arg_name: &str) -> rsbinder::status::Result { + // TODO: Implement this + log::warn!("isDeclared is not implemented"); Ok(false) } fn getDeclaredInstances(&self,_arg_iface: &str) -> rsbinder::status::Result> { - println!("getDeclaredInstances"); + log::warn!("getDeclaredInstances is not implemented"); Ok(vec![]) } fn updatableViaApex(&self,_arg_name: &str) -> rsbinder::status::Result> { - println!("updatableViaApex"); + log::warn!("updatableViaApex is not implemented"); Ok(None) } fn getConnectionInfo(&self,_arg_name: &str) -> rsbinder::status::Result> { - println!("getConnectionInfo"); + log::warn!("getConnectionInfo is not implemented"); Ok(None) } - fn registerClientCallback(&self,_arg_name: &str,_arg_service: &rsbinder::SIBinder,_arg_callback: &rsbinder::Strong) -> rsbinder::status::Result<()> { - println!("registerClientCallback"); + fn registerClientCallback(&self, name: &str, arg_service: &rsbinder::SIBinder, arg_callback: &rsbinder::Strong) -> rsbinder::status::Result<()> { + let mut inner = self.inner.lock().unwrap(); + + let service = if let Some(service) = inner.name_to_service.get(name) { + service + } else { + let msg = format!("registerClientCallback could not find service {}", name); + log::warn!("{}", &msg); + return Err((ExceptionCode::IllegalArgument, msg.as_str()).into()); + }; + + if service.context.pid != rsbinder::thread_state::CallingContext::new().pid { + let msg = format!("{:?} Only a server can register for client callbacks (for {})", + service.context, name); + log::warn!("{}", &msg); + return Err((ExceptionCode::Security, msg.as_str()).into()); + } + + if service.binder != *arg_service { + let msg = format!("registerClientCallback called with wrong service {}", name); + log::warn!("{}", &msg); + return Err((ExceptionCode::IllegalArgument, msg.as_str()).into()); + } + + arg_callback.as_binder().link_to_death(Arc::downgrade(&(inner.death_recipient.clone() as Arc)))?; + + if service.has_clients { + arg_callback.onClients(&service.binder, true) + .unwrap_or_else(|e| { + log::error!("Failed to notify client callback: {:?}", e); + }); + } + + inner.name_to_client_callbacks + .entry(name.to_string()) + .or_insert_with(Vec::new) + .push(arg_callback.clone()); + + inner.handle_service_client_callback(2 /* sm + transaction */, + name, false)?; + Ok(()) } - fn tryUnregisterService(&self,_arg_name: &str,_arg_service: &rsbinder::SIBinder) -> rsbinder::status::Result<()> { - println!("tryUnregisterService"); + fn tryUnregisterService(&self, name: &str, arg_service: &rsbinder::SIBinder) -> rsbinder::status::Result<()> { + let context = rsbinder::thread_state::CallingContext::new(); + + let mut inner = self.inner.lock().unwrap(); + let service = if let Some(service) = inner.name_to_service.get(name) { + service + } else { + let msg = format!("{:?} Tried to unregister {}, but that service wasn't registered to begin with.", + context, name); + log::warn!("{}", &msg); + return Err((ExceptionCode::IllegalArgument, msg.as_str()).into()); + }; + + if service.context.pid != rsbinder::thread_state::CallingContext::new().pid { + let msg = format!("{:?} Only a server can register for client callbacks (for {})", + service.context, name); + log::warn!("{}", &msg); + return Err((ExceptionCode::Security, msg.as_str()).into()); + } + + if arg_service.clone() != service.binder.clone() { + let msg = format!("{:?} Tried to unregister {}, but a different service is registered under this name.", + context, name); + log::warn!("{}", &msg); + return Err((ExceptionCode::IllegalArgument, msg.as_str()).into()); + } + + if service.guarentee_client { + let msg = format!("{:?} Tried to unregister {}, but there is about to be a client.", + context, name); + log::warn!("{}", &msg); + return Err((ExceptionCode::IllegalState, msg.as_str()).into()); + } + + let res = inner.handle_service_client_callback(2, name, false); + if res.is_err() { + let msg = format!("{:?} Tried to unregister {}, but there are clients.", + context, name); + log::warn!("{}", &msg); + inner.name_to_service.get_mut(name).map(|service| { + service.guarentee_client = true; + }); + return Err((ExceptionCode::IllegalState, msg.as_str()).into()); + } + + inner.name_to_service.remove(name); + Ok(()) } fn getServiceDebugInfo(&self) -> rsbinder::status::Result> { - println!("getServiceDebugInfo"); - Ok(vec![]) + let inner = self.inner.lock().unwrap(); + + let mut out = Vec::with_capacity(inner.name_to_service.len()); + + for (name, service) in inner.name_to_service.iter() { + out.push(hub::android::os::ServiceDebugInfo::ServiceDebugInfo { + name: name.clone(), + debugPid: service.context.pid, + }); + } + + Ok(out) } } @@ -251,10 +505,10 @@ fn main() -> std::result::Result<(), Box> { env_logger::Builder::from_env(Env::default().default_filter_or("warn")).init(); - ProcessState::init(DEFAULT_BINDER_PATH, 1); + ProcessState::init(DEFAULT_BINDER_PATH, 0); // Create a binder service. - let service = BnServiceManager::new_binder(ServiceManager::default()); + let service = BnServiceManager::new_binder(ServiceManager::new()); service.addService("manager", &service.as_binder(), false, DUMP_FLAG_PRIORITY_DEFAULT)?; ProcessState::as_self().become_context_manager(service.as_binder())?; diff --git a/rsbinder/build.rs b/rsbinder/build.rs index 64f0886..c4a2320 100644 --- a/rsbinder/build.rs +++ b/rsbinder/build.rs @@ -17,4 +17,11 @@ fn main() { .set_crate_support(true) .generate().unwrap(); + + rsbinder_aidl::Builder::new() + .source(PathBuf::from("aidl/android/sm/tests/IFoo.aidl")) + + .output(PathBuf::from("sm_tests_aidl.rs")) + .set_crate_support(true) + .generate().unwrap(); } \ No newline at end of file diff --git a/rsbinder/src/binder.rs b/rsbinder/src/binder.rs index 2476e5a..d851bb4 100644 --- a/rsbinder/src/binder.rs +++ b/rsbinder/src/binder.rs @@ -19,7 +19,7 @@ use std::mem::ManuallyDrop; use std::ops::Deref; -use std::sync::Arc; +use std::sync::{self, Arc}; use std::any::Any; use std::fmt::{Debug, Formatter}; use std::marker::PhantomData; @@ -153,12 +153,12 @@ pub trait IBinder: Any + Send + Sync { /// INVALID_OPERATION code being returned and nothing happening. /// /// This link always holds a weak reference to its recipient. - fn link_to_death(&self, recipient: Arc) -> Result<()>; + fn link_to_death(&self, recipient: sync::Weak) -> Result<()>; /// Remove a previously registered death notification. /// The recipient will no longer be called if this object /// dies. - fn unlink_to_death(&self, recipient: Arc) -> Result<()>; + fn unlink_to_death(&self, recipient: sync::Weak) -> Result<()>; /// Send a ping transaction to this object fn ping_binder(&self) -> Result<()>; @@ -169,7 +169,9 @@ pub trait IBinder: Any + Send + Sync { /// To convert the interface to a transactable object fn as_transactable(&self) -> Option<&dyn Transactable>; + /// Retrieve the descriptor of this object. fn descriptor(&self) -> &str; + /// Retrieve if this object is remote. fn is_remote(&self) -> bool; fn inc_strong(&self, strong: &SIBinder) -> Result<()>; diff --git a/rsbinder/src/hub/servicemanager.rs b/rsbinder/src/hub/servicemanager.rs index 15371bb..4464e65 100644 --- a/rsbinder/src/hub/servicemanager.rs +++ b/rsbinder/src/hub/servicemanager.rs @@ -30,7 +30,7 @@ pub fn default() -> Arc { unsafe { INIT.call_once(|| { let process = ProcessState::as_self(); - let service_manager = process.context_object().unwrap(); + let service_manager = process.context_object().expect("Failed to get context object"); let service_manager = BpServiceManager::from_binder(service_manager).unwrap(); GLOBAL_SM = Some(Arc::new(service_manager)); // Replace 0 with your initial value IS_INIT.store(true, Ordering::SeqCst); @@ -86,12 +86,14 @@ pub fn add_service(identifier: &str, binder: SIBinder) -> std::result::Result<() /// Request a callback when a service is registered. pub fn register_for_notifications(name: &str, callback: &crate::Strong) -> Result<()> { - default().registerForNotifications(name, callback).map_err(|e| e.into()) + default().registerForNotifications(name, callback) + .map_err(|e| e.into()) } /// Unregisters all requests for notifications for a specific callback. pub fn unregister_for_notifications(name: &str, callback: &crate::Strong) -> Result<()> { - default().unregisterForNotifications(name, callback).map_err(|e| e.into()) + default().unregisterForNotifications(name, callback) + .map_err(|e| e.into()) } /// Returns whether a given interface is declared on the device, even if it @@ -114,78 +116,3 @@ pub fn get_interface(name: &str) -> Result> { None => Err(StatusCode::NameNotFound), } } - -#[cfg(test)] -mod tests { - #![allow(non_snake_case)] - - use super::*; - use std::sync::OnceLock; - - fn setup() { - static INIT: OnceLock = OnceLock::new(); - - let _ = INIT.get_or_init(|| { - env_logger::init(); - crate::ProcessState::init(crate::DEFAULT_BINDER_PATH, 0); - true - }); - } - - #[test] - fn test_get_check_list_service() -> crate::Result<()> { - setup(); - - #[cfg(target_os = "android")] - { - let manager_name = "manager"; - let binder = get_service(manager_name); - assert!(binder.is_some()); - - let binder = check_service(manager_name); - assert!(binder.is_some()); - } - - let unknown_name = "unknown_service"; - let binder = get_service(unknown_name); - assert!(binder.is_none()); - let binder = check_service(unknown_name); - assert!(binder.is_none()); - - let services = list_services(DUMP_FLAG_PRIORITY_DEFAULT); - assert!(!services.is_empty()); - - Ok(()) - } - - #[test] - fn test_notifications() -> crate::Result<()> { - setup(); - - struct MyServiceCallback {} - impl crate::Interface for MyServiceCallback {} - impl IServiceCallback for MyServiceCallback { - fn onRegistration(&self, name: &str, service: &crate::SIBinder) -> crate::status::Result<()> { - println!("onRegistration: {} {:?}", name, service); - Ok(()) - } - } - - let callback = BnServiceCallback::new_binder(MyServiceCallback{}); - - register_for_notifications("mytest_service", &callback)?; - - unregister_for_notifications("mytest_service", &callback)?; - - Ok(()) - } - - #[test] - fn test_others() -> crate::Result<()> { - setup(); - - assert!(!is_declared("android.hardware.usb.IUsb/default")); - - Ok(()) - } -} diff --git a/rsbinder/src/native.rs b/rsbinder/src/native.rs index 903c08d..1b30e98 100644 --- a/rsbinder/src/native.rs +++ b/rsbinder/src/native.rs @@ -17,7 +17,7 @@ * limitations under the License. */ -use std::sync::Arc; +use std::sync::{Arc, Weak}; use std::ops::{Deref, DerefMut}; use std::any::Any; use std::convert::TryFrom; @@ -81,7 +81,7 @@ impl Inner { } impl IBinder for Inner { - fn link_to_death(&self, _recipient: Arc) -> Result<()> { + fn link_to_death(&self, _recipient: Weak) -> Result<()> { log::error!("Binder does not support link_to_death."); Err(StatusCode::InvalidOperation) } @@ -89,7 +89,7 @@ impl IBinder for Inner { /// Remove a previously registered death notification. /// The recipient will no longer be called if this object /// dies. - fn unlink_to_death(&self, _recipient: Arc) -> Result<()> { + fn unlink_to_death(&self, _recipient: Weak) -> Result<()> { log::error!("Binder does not support unlink_to_death."); Err(StatusCode::InvalidOperation) } diff --git a/rsbinder/src/process_state.rs b/rsbinder/src/process_state.rs index 9a5ba3c..703ef0f 100644 --- a/rsbinder/src/process_state.rs +++ b/rsbinder/src/process_state.rs @@ -268,6 +268,24 @@ impl ProcessState { // to return too high of a value. } + pub fn strong_ref_count_for_node(&self, node: &ProxyHandle) -> Result { + let mut info = binder::binder_node_info_for_ref { + handle: node.handle(), + strong_count: 0, + weak_count: 0, + reserved1: 0, + reserved2: 0, + reserved3: 0, + }; + + binder::get_node_info_for_ref(&self.driver, &mut info) + .map_err(|e| { + log::error!("Binder ioctl(BINDER_GET_NODE_INFO_FOR_REF) failed: {:?}", e); + e + })?; + Ok(info.strong_count as usize) + } + pub fn join_thread_pool() -> Result<()> { thread_state::join_thread_pool(true) } diff --git a/rsbinder/src/proxy.rs b/rsbinder/src/proxy.rs index 80069bb..8ca0c6b 100644 --- a/rsbinder/src/proxy.rs +++ b/rsbinder/src/proxy.rs @@ -6,7 +6,7 @@ use std::fmt::{Debug, Formatter}; use std::mem::ManuallyDrop; use std::os::fd::IntoRawFd; use std::sync::atomic::AtomicBool; -use std::sync::{Arc, RwLock}; +use std::sync::{self, Arc, RwLock}; use crate::{ parcel::*, @@ -22,7 +22,7 @@ pub struct ProxyHandle { descriptor: String, stability: Stability, obituary_sent: AtomicBool, - recipients: RwLock>>, + recipients: RwLock>>, strong: RefCounter, weak: RefCounter, @@ -72,8 +72,24 @@ impl ProxyHandle { thread_state::flush_commands()?; } + // To remember the recipients to remove + let mut recipients_to_remove = Vec::new(); for recipient in recipients.iter() { - recipient.binder_died(who); + if let Some(recipient) = recipient.upgrade() { + recipient.binder_died(who); + } else { + // The recipient is already dead + recipients_to_remove.push(recipient.clone()); + } + } + + drop(recipients); // Release the read lock before acquiring the write lock + + if !recipients_to_remove.is_empty() { + let mut recipients = self.recipients.write().unwrap(); + for recipient in recipients_to_remove { + recipients.retain(|r| !sync::Weak::ptr_eq(r, &recipient)); + } } Ok(()) @@ -112,7 +128,7 @@ impl PartialEq for ProxyHandle { impl IBinder for ProxyHandle { /// Register a death notification for this object. - fn link_to_death(&self, recipient: Arc) -> Result<()> { + fn link_to_death(&self, recipient: sync::Weak) -> Result<()> { if self.obituary_sent.load(std::sync::atomic::Ordering::Relaxed) { return Err(StatusCode::DeadObject); } else { @@ -130,13 +146,13 @@ impl IBinder for ProxyHandle { /// Remove a previously registered death notification. /// The recipient will no longer be called if this object /// dies. - fn unlink_to_death(&self, recipient: Arc) -> Result<()> { + fn unlink_to_death(&self, recipient: sync::Weak) -> Result<()> { if self.obituary_sent.load(std::sync::atomic::Ordering::Relaxed) { return Err(StatusCode::DeadObject); } else { let mut recipients = self.recipients.write().unwrap(); - recipients.retain(|r| !Arc::ptr_eq(r, &recipient)); + recipients.retain(|r| !sync::Weak::ptr_eq(r, &recipient)); if recipients.is_empty() { thread_state::clear_death_notification(self.handle())?; thread_state::flush_commands()?; diff --git a/rsbinder/src/sys/mod.rs b/rsbinder/src/sys/mod.rs index f597b18..8c62899 100644 --- a/rsbinder/src/sys/mod.rs +++ b/rsbinder/src/sys/mod.rs @@ -64,6 +64,7 @@ pub mod binder { // nix::ioctl_readwrite!(write_read, b'b', 1, binder_write_read); pub(crate) fn write_read(fd: Fd, write_read: &mut binder_write_read) -> std::result::Result<(), io::Errno> { unsafe { + // BINDER_WRITE_READ let ctl = ioctl::Updater::, binder_write_read>::new(write_read); ioctl::ioctl(fd, ctl) } @@ -72,6 +73,7 @@ pub mod binder { // nix::ioctl_write_ptr!(set_max_threads, b'b', 5, __u32); pub(crate) fn set_max_threads(fd: Fd, max_threads: u32) -> std::result::Result<(), io::Errno> { unsafe { + // BINDER_SET_MAX_THREADS let ctl = ioctl::Setter::, _>::new(max_threads); ioctl::ioctl(fd, ctl) } @@ -80,6 +82,7 @@ pub mod binder { // nix::ioctl_write_ptr!(set_context_mgr, b'b', 7, __s32); pub(crate) fn set_context_mgr(fd: Fd, pid: i32) -> std::result::Result<(), io::Errno> { unsafe { + // BINDER_SET_CONTEXT_MGR let ctl = ioctl::Setter::, _>::new(pid); ioctl::ioctl(fd, ctl) } @@ -88,6 +91,7 @@ pub mod binder { // nix::ioctl_readwrite!(version, b'b', 9, binder_version); pub(crate) fn version(fd: Fd, ver: &mut binder_version) -> std::result::Result<(), io::Errno> { unsafe { + // BINDER_VERSION let ctl = ioctl::Updater::, binder_version>::new(ver); ioctl::ioctl(fd, ctl) } @@ -96,6 +100,7 @@ pub mod binder { // nix::ioctl_write_ptr!(set_context_mgr_ext, b'b', 13, flat_binder_object); pub(crate) fn set_context_mgr_ext(fd: Fd, obj: flat_binder_object) -> std::result::Result<(), io::Errno> { unsafe { + // BINDER_SET_CONTEXT_MGR_EXT let ctl = ioctl::Setter::, _>::new(obj); ioctl::ioctl(fd, ctl) } @@ -104,6 +109,7 @@ pub mod binder { // nix::ioctl_write_ptr!(enable_oneway_spam_detection, b'b', 16, __u32); pub(crate) fn enable_oneway_spam_detection(fd: Fd, enable: __u32) -> std::result::Result<(), io::Errno> { unsafe { + // BINDER_ENABLE_ONEWAY_SPAM_DETECTION let ctl = ioctl::Setter::, _>::new(enable); ioctl::ioctl(fd, ctl) } @@ -112,6 +118,7 @@ pub mod binder { // nix::ioctl_readwrite!(binder_ctl_add, b'b', 1, binderfs_device); pub(crate) fn binder_ctl_add(fd: Fd, device: &mut binderfs_device) -> std::result::Result<(), io::Errno> { unsafe { + // BINDER_CTL_ADD let ctl = ioctl::Updater::, _>::new(device); ioctl::ioctl(fd, ctl) } @@ -120,6 +127,7 @@ pub mod binder { // nix::ioctl_write_ptr!(set_idle_timeout, b'b', 3, __s64); pub(crate) fn set_idle_timeout(fd: Fd, timeout: i64) -> std::result::Result<(), io::Errno> { unsafe { + // BINDER_SET_IDLE_TIMEOUT let ctl = ioctl::Setter::, _>::new(timeout); ioctl::ioctl(fd, ctl) } @@ -128,6 +136,7 @@ pub mod binder { // nix::ioctl_write_ptr!(set_idle_priority, b'b', 6, __s32); pub(crate) fn set_idle_priority(fd: Fd, priority: i32) -> std::result::Result<(), io::Errno> { unsafe { + // BINDER_SET_IDLE_PRIORITY let ctl = ioctl::Setter::, _>::new(priority); ioctl::ioctl(fd, ctl) } @@ -136,6 +145,7 @@ pub mod binder { // nix::ioctl_write_ptr!(thread_exit, b'b', 8, __s32); pub(crate) fn thread_exit(fd: Fd, pid: i32) -> std::result::Result<(), io::Errno> { unsafe { + // BINDER_THREAD_EXIT let ctl = ioctl::Setter::, _>::new(pid); ioctl::ioctl(fd, ctl) } @@ -144,6 +154,7 @@ pub mod binder { // nix::ioctl_readwrite!(get_node_debug_info, b'b', 11, binder_node_debug_info); pub(crate) fn get_node_debug_info(fd: Fd, node_debug_info: &mut binder_node_debug_info) -> std::result::Result<(), rustix::io::Errno> { unsafe { + // BINDER_GET_NODE_DEBUG_INFO let ctl = ioctl::Updater::, _>::new(node_debug_info); ioctl::ioctl(fd, ctl) } @@ -152,6 +163,7 @@ pub mod binder { // nix::ioctl_readwrite!(get_node_info_for_ref, b'b', 12, binder_node_info_for_ref); pub(crate) fn get_node_info_for_ref(fd: Fd, node_info: &mut binder_node_info_for_ref) -> std::result::Result<(), rustix::io::Errno> { unsafe { + // BINDER_GET_NODE_INFO_FOR_REF let ctl = ioctl::Updater::, _>::new(node_info); ioctl::ioctl(fd, ctl) } @@ -160,6 +172,7 @@ pub mod binder { // nix::ioctl_write_ptr!(freeze, b'b', 14, binder_freeze_info); pub(crate) fn freeze(fd: Fd, info: binder_freeze_info) -> std::result::Result<(), io::Errno> { unsafe { + // BINDER_FREEZE let ctl = ioctl::Setter::, _>::new(info); ioctl::ioctl(fd, ctl) } @@ -168,6 +181,7 @@ pub mod binder { // nix::ioctl_readwrite!(get_frozen_info, b'b', 15, binder_frozen_status_info); pub(crate) fn get_frozen_info(fd: Fd, frozen_info: &mut binder_frozen_status_info) -> std::result::Result<(), io::Errno> { unsafe { + // BINDER_GET_FROZEN_INFO let ctl = ioctl::Updater::, _>::new(frozen_info); ioctl::ioctl(fd, ctl) } diff --git a/rsbinder/src/thread_state.rs b/rsbinder/src/thread_state.rs index 7cd4a29..f51f9fa 100644 --- a/rsbinder/src/thread_state.rs +++ b/rsbinder/src/thread_state.rs @@ -17,6 +17,8 @@ * limitations under the License. */ +use std::ffi::{CString, CStr}; +use std::fmt::Debug; use std::sync::{atomic::Ordering, Arc}; use std::cell::RefCell; use log::error; @@ -109,8 +111,8 @@ pub(crate) const UNSET_WORK_SOURCE: i32 = -1; #[derive(Debug, Clone, Copy)] struct TransactionState { calling_pid: binder::pid_t, - _calling_sid: *const u8, - _calling_uid: binder::uid_t, + calling_sid: *const u8, + calling_uid: binder::uid_t, // strict_mode_policy: i32, last_transaction_binder_flags: u32, work_source: binder::uid_t, @@ -121,8 +123,8 @@ impl TransactionState { fn from_transaction_data(data: &binder::binder_transaction_data_secctx) -> Self { TransactionState { calling_pid: data.transaction_data.sender_pid, - _calling_sid: data.secctx as _, - _calling_uid: data.transaction_data.sender_euid, + calling_sid: data.secctx as _, + calling_uid: data.transaction_data.sender_euid, // strict_mode_policy: 0, last_transaction_binder_flags: data.transaction_data.flags, work_source: 0, @@ -1065,26 +1067,43 @@ pub(crate) fn clear_death_notification(handle: u32) -> Result<()> { }) } +#[derive(Debug)] pub struct CallingContext { pub pid: binder::pid_t, pub uid: binder::uid_t, - pub sid: *const u8, + pub sid: Option, } -pub(crate) fn _get_calling_context() -> Result { - THREAD_STATE.with(|thread_state| -> Result { - let thread_state = thread_state.borrow(); - let transaction = thread_state.transaction.as_ref().ok_or(StatusCode::Unknown)?; - let calling_pid = transaction.calling_pid; - let calling_uid = transaction._calling_uid; - let calling_sid = transaction._calling_sid; - - Ok(CallingContext { - pid: calling_pid, - uid: calling_uid, - sid: calling_sid, +impl CallingContext { + pub fn new() -> CallingContext { + THREAD_STATE.with(|thread_state| -> CallingContext { + let thread_state = thread_state.borrow(); + match thread_state.transaction.as_ref() { + Some(transaction) => { + let calling_sid = if transaction.calling_sid != std::ptr::null() { + unsafe { + Some(CStr::from_ptr(transaction.calling_sid as _).to_owned()) + } + } else { + None + }; + CallingContext { + pid: transaction.calling_pid, + uid: transaction.calling_uid, + sid: calling_sid, + } + } + None => { + log::debug!("CallingContext::new() called outside of transaction"); + CallingContext { + pid: rustix::process::getpid().as_raw_nonzero().get() as _, + uid: rustix::process::getuid().as_raw(), + sid: None, + } + } + } }) - }) + } } pub fn is_handling_transaction() -> bool { diff --git a/tests/aidl/android/aidl/tests/sm/IFoo.aidl b/tests/aidl/android/aidl/tests/sm/IFoo.aidl new file mode 100644 index 0000000..fc761fd --- /dev/null +++ b/tests/aidl/android/aidl/tests/sm/IFoo.aidl @@ -0,0 +1,5 @@ +package android.aidl.tests.sm; + +interface IFoo { + void hello(); +} diff --git a/tests/build.rs b/tests/build.rs index 849362f..c60a121 100644 --- a/tests/build.rs +++ b/tests/build.rs @@ -36,6 +36,7 @@ fn main() { .source(PathBuf::from("aidl/android/aidl/versioned/tests/BazUnion.aidl")) .source(PathBuf::from("aidl/android/aidl/versioned/tests/Foo.aidl")) .source(PathBuf::from("aidl/android/aidl/versioned/tests/IFooInterface.aidl")) + .source(PathBuf::from("aidl/android/aidl/tests/sm/IFoo.aidl")) .output(PathBuf::from("test_aidl.rs")) .generate().unwrap(); diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 5445f3d..2ce522f 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -1,4 +1,5 @@ // Copyright 2022 Jeff Kim // SPDX-License-Identifier: Apache-2.0 -mod test_client; \ No newline at end of file +mod test_client; +mod test_sm; \ No newline at end of file diff --git a/tests/src/test_sm.rs b/tests/src/test_sm.rs new file mode 100644 index 0000000..a6967e2 --- /dev/null +++ b/tests/src/test_sm.rs @@ -0,0 +1,116 @@ +#![allow(non_snake_case, dead_code, unused_imports, unused_macros)] + +use env_logger::Env; + +pub use rsbinder::*; + +include!(concat!(env!("OUT_DIR"), "/test_aidl.rs")); + +pub(crate) use android::aidl::tests::sm::IFoo::{IFoo, BnFoo, BpFoo}; + +pub(crate) struct IFooService { +} + +impl rsbinder::Interface for IFooService { +} + +impl IFoo for IFooService { + // Implement the echo method. + fn hello(&self) -> rsbinder::status::Result<()> { + Ok(()) + } +} + +use super::*; +use std::sync::OnceLock; + +fn setup() { + // static INIT: OnceLock = OnceLock::new(); + + // let _ = INIT.get_or_init(|| { + // env_logger::init(); + // rsbinder::ProcessState::init(rsbinder::DEFAULT_BINDER_PATH, 0); + // true + // }); +} + +#[test] +fn test_add_service() -> rsbinder::Result<()> { + setup(); + + let service = BnFoo::new_binder(IFooService{}); + assert!(hub::add_service("", service.as_binder()).is_err()); + + assert_eq!(hub::add_service("foo", service.as_binder()), Ok(())); + + // The maximum length of service name is 127. + let s = std::iter::repeat('a').take(127).collect::(); + assert!(hub::add_service(&s, service.as_binder()).is_ok()); + + let s = std::iter::repeat('a').take(128).collect::(); + assert!(hub::add_service(&s, service.as_binder()).is_err()); + + // Weird characters are not allowed. + assert!(hub::add_service("happy$foo$fo", service.as_binder()).is_err()); + + // Overwrite the service + assert_eq!(hub::add_service("foo", service.as_binder()), Ok(())); + + Ok(()) +} + +#[test] +fn test_get_check_list_service() -> rsbinder::Result<()> { + setup(); + + #[cfg(target_os = "android")] + { + let manager_name = "manager"; + let binder = get_service(manager_name); + assert!(binder.is_some()); + + let binder = check_service(manager_name); + assert!(binder.is_some()); + } + + let unknown_name = "unknown_service"; + let binder = hub::get_service(unknown_name); + assert!(binder.is_none()); + let binder = hub::check_service(unknown_name); + assert!(binder.is_none()); + + let services = hub::list_services(hub::DUMP_FLAG_PRIORITY_DEFAULT); + assert!(!services.is_empty()); + + Ok(()) +} + +#[test] +fn test_notifications() -> rsbinder::Result<()> { + setup(); + + struct MyServiceCallback {} + impl rsbinder::Interface for MyServiceCallback {} + impl hub::IServiceCallback for MyServiceCallback { + fn onRegistration(&self, name: &str, service: &rsbinder::SIBinder) -> rsbinder::status::Result<()> { + println!("onRegistration: {} {:?}", name, service); + Ok(()) + } + } + + let callback = hub::BnServiceCallback::new_binder(MyServiceCallback{}); + + hub::register_for_notifications("mytest_service", &callback)?; + hub::unregister_for_notifications("mytest_service", &callback)?; + + Ok(()) +} + +#[test] +fn test_others() -> rsbinder::Result<()> { + setup(); + + assert!(!hub::is_declared("android.hardware.usb.IUsb/default")); + + Ok(()) +}