use std; use std::ffi::OsString; use std::io::Error as IoError; use std::mem; use std::ops::{Deref, DerefMut}; use std::os::windows::ffi::OsStringExt; use std::ptr; use std::slice; use std::sync::{Arc, Mutex, MutexGuard}; use DefaultFormatError; use Format; use FormatsEnumerationError; use SampleFormat; use SampleRate; use SupportedFormat; use COMMON_SAMPLE_RATES; use super::check_result; use super::com; use super::winapi::Interface; use super::winapi::shared::devpkey; use super::winapi::shared::ksmedia; use super::winapi::shared::guiddef::{ GUID, }; use super::winapi::shared::winerror; use super::winapi::shared::minwindef::{ DWORD, }; use super::winapi::shared::mmreg; use super::winapi::shared::wtypes; use super::winapi::um::coml2api; use super::winapi::um::audioclient::{ IAudioClient, IID_IAudioClient, AUDCLNT_E_DEVICE_INVALIDATED, }; use super::winapi::um::audiosessiontypes::{ AUDCLNT_SHAREMODE_SHARED, }; use super::winapi::um::combaseapi::{ CoCreateInstance, CoTaskMemFree, CLSCTX_ALL, PropVariantClear, }; use super::winapi::um::mmdeviceapi::{ eAll, eCapture, eConsole, eRender, CLSID_MMDeviceEnumerator, DEVICE_STATE_ACTIVE, EDataFlow, IMMDevice, IMMDeviceCollection, IMMDeviceEnumerator, IMMEndpoint, }; pub type SupportedInputFormats = std::vec::IntoIter; pub type SupportedOutputFormats = std::vec::IntoIter; /// Wrapper because of that stupid decision to remove `Send` and `Sync` from raw pointers. #[derive(Copy, Clone)] struct IAudioClientWrapper(*mut IAudioClient); unsafe impl Send for IAudioClientWrapper { } unsafe impl Sync for IAudioClientWrapper { } /// An opaque type that identifies an end point. pub struct Device { device: *mut IMMDevice, /// We cache an uninitialized `IAudioClient` so that we can call functions from it without /// having to create/destroy audio clients all the time. future_audio_client: Arc>>, // TODO: add NonZero around the ptr } struct Endpoint { endpoint: *mut IMMEndpoint, } enum WaveFormat { Ex(mmreg::WAVEFORMATEX), Extensible(mmreg::WAVEFORMATEXTENSIBLE), } // Use RAII to make sure CoTaskMemFree is called when we are responsible for freeing. struct WaveFormatExPtr(*mut mmreg::WAVEFORMATEX); impl Drop for WaveFormatExPtr { fn drop(&mut self) { unsafe { CoTaskMemFree(self.0 as *mut _); } } } impl WaveFormat { // Given a pointer to some format, returns a valid copy of the format. pub fn copy_from_waveformatex_ptr(ptr: *const mmreg::WAVEFORMATEX) -> Option { unsafe { match (*ptr).wFormatTag { mmreg::WAVE_FORMAT_PCM | mmreg::WAVE_FORMAT_IEEE_FLOAT => { Some(WaveFormat::Ex(*ptr)) }, mmreg::WAVE_FORMAT_EXTENSIBLE => { let extensible_ptr = ptr as *const mmreg::WAVEFORMATEXTENSIBLE; Some(WaveFormat::Extensible(*extensible_ptr)) }, _ => None, } } } // Get the pointer to the WAVEFORMATEX struct. pub fn as_ptr(&self) -> *const mmreg::WAVEFORMATEX { self.deref() as *const _ } } impl Deref for WaveFormat { type Target = mmreg::WAVEFORMATEX; fn deref(&self) -> &Self::Target { match *self { WaveFormat::Ex(ref f) => f, WaveFormat::Extensible(ref f) => &f.Format, } } } impl DerefMut for WaveFormat { fn deref_mut(&mut self) -> &mut Self::Target { match *self { WaveFormat::Ex(ref mut f) => f, WaveFormat::Extensible(ref mut f) => &mut f.Format, } } } unsafe fn immendpoint_from_immdevice(device: *const IMMDevice) -> *mut IMMEndpoint { let mut endpoint: *mut IMMEndpoint = mem::uninitialized(); check_result((*device).QueryInterface(&IMMEndpoint::uuidof(), &mut endpoint as *mut _ as *mut _)) .expect("could not query IMMDevice interface for IMMEndpoint"); endpoint } unsafe fn data_flow_from_immendpoint(endpoint: *const IMMEndpoint) -> EDataFlow { let mut data_flow = mem::uninitialized(); check_result((*endpoint).GetDataFlow(&mut data_flow)) .expect("could not get endpoint data_flow"); data_flow } // Given the audio client and format, returns whether or not the format is supported. pub unsafe fn is_format_supported( client: *const IAudioClient, waveformatex_ptr: *const mmreg::WAVEFORMATEX, ) -> Result { /* // `IsFormatSupported` checks whether the format is supported and fills // a `WAVEFORMATEX` let mut dummy_fmt_ptr: *mut mmreg::WAVEFORMATEX = mem::uninitialized(); let hresult = (*audio_client) .IsFormatSupported(share_mode, &format_attempt.Format, &mut dummy_fmt_ptr); // we free that `WAVEFORMATEX` immediately after because we don't need it if !dummy_fmt_ptr.is_null() { CoTaskMemFree(dummy_fmt_ptr as *mut _); } // `IsFormatSupported` can return `S_FALSE` (which means that a compatible format // has been found) but we also treat this as an error match (hresult, check_result(hresult)) { (_, Err(ref e)) if e.raw_os_error() == Some(AUDCLNT_E_DEVICE_INVALIDATED) => { (*audio_client).Release(); return Err(CreationError::DeviceNotAvailable); }, (_, Err(e)) => { (*audio_client).Release(); panic!("{:?}", e); }, (winerror::S_FALSE, _) => { (*audio_client).Release(); return Err(CreationError::FormatNotSupported); }, (_, Ok(())) => (), }; */ // Check if the given format is supported. let is_supported = |waveformatex_ptr, mut closest_waveformatex_ptr| { let result = (*client).IsFormatSupported( AUDCLNT_SHAREMODE_SHARED, waveformatex_ptr, &mut closest_waveformatex_ptr, ); // `IsFormatSupported` can return `S_FALSE` (which means that a compatible format // has been found, but not an exact match) so we also treat this as unsupported. match (result, check_result(result)) { (_, Err(ref e)) if e.raw_os_error() == Some(AUDCLNT_E_DEVICE_INVALIDATED) => { return Err(FormatsEnumerationError::DeviceNotAvailable); }, (_, Err(_)) => { Ok(false) }, (winerror::S_FALSE, _) => { Ok(false) }, (_, Ok(())) => { Ok(true) }, } }; // First we want to retrieve a pointer to the `WAVEFORMATEX`. // Although `GetMixFormat` writes the format to a given `WAVEFORMATEX` pointer, // the pointer itself may actually point to a `WAVEFORMATEXTENSIBLE` structure. // We check the wFormatTag to determine this and get a pointer to the correct type. match (*waveformatex_ptr).wFormatTag { mmreg::WAVE_FORMAT_PCM | mmreg::WAVE_FORMAT_IEEE_FLOAT => { let mut closest_waveformatex = *waveformatex_ptr; let mut closest_waveformatex_ptr = &mut closest_waveformatex as *mut _; is_supported(waveformatex_ptr, closest_waveformatex_ptr) }, mmreg::WAVE_FORMAT_EXTENSIBLE => { let waveformatextensible_ptr = waveformatex_ptr as *const mmreg::WAVEFORMATEXTENSIBLE; let mut closest_waveformatextensible = *waveformatextensible_ptr; let closest_waveformatextensible_ptr = &mut closest_waveformatextensible as *mut _; let mut closest_waveformatex_ptr = closest_waveformatextensible_ptr as *mut mmreg::WAVEFORMATEX; is_supported(waveformatex_ptr, closest_waveformatex_ptr) }, _ => Ok(false), } } // Get a cpal Format from a WAVEFORMATEX. unsafe fn format_from_waveformatex_ptr( waveformatex_ptr: *const mmreg::WAVEFORMATEX, ) -> Option { fn cmp_guid(a: &GUID, b: &GUID) -> bool { a.Data1 == b.Data1 && a.Data2 == b.Data2 && a.Data3 == b.Data3 && a.Data4 == b.Data4 } let data_type = match ((*waveformatex_ptr).wBitsPerSample, (*waveformatex_ptr).wFormatTag) { (16, mmreg::WAVE_FORMAT_PCM) => SampleFormat::I16, (32, mmreg::WAVE_FORMAT_IEEE_FLOAT) => SampleFormat::F32, (n_bits, mmreg::WAVE_FORMAT_EXTENSIBLE) => { let waveformatextensible_ptr = waveformatex_ptr as *const mmreg::WAVEFORMATEXTENSIBLE; let sub = (*waveformatextensible_ptr).SubFormat; if n_bits == 16 && cmp_guid(&sub, &ksmedia::KSDATAFORMAT_SUBTYPE_PCM) { SampleFormat::I16 } else if n_bits == 32 && cmp_guid(&sub, &ksmedia::KSDATAFORMAT_SUBTYPE_IEEE_FLOAT) { SampleFormat::F32 } else { return None; } }, // Unknown data format returned by GetMixFormat. _ => return None, }; let format = Format { channels: (*waveformatex_ptr).nChannels as _, sample_rate: SampleRate((*waveformatex_ptr).nSamplesPerSec), data_type: data_type, }; Some(format) } unsafe impl Send for Device { } unsafe impl Sync for Device { } impl Device { pub fn name(&self) -> String { unsafe { // Open the device's property store. let mut property_store = ptr::null_mut(); (*self.device).OpenPropertyStore(coml2api::STGM_READ, &mut property_store); // Get the endpoint's friendly-name property. let mut property_value = mem::zeroed(); check_result( (*property_store).GetValue( &devpkey::DEVPKEY_Device_FriendlyName as *const _ as *const _, &mut property_value ) ).expect("failed to get friendly-name from property store"); // Read the friendly-name from the union data field, expecting a *const u16. assert_eq!(property_value.vt, wtypes::VT_LPWSTR as _); let ptr_usize: usize = *(&property_value.data as *const _ as *const usize); let ptr_utf16 = ptr_usize as *const u16; // Find the length of the friendly name. let mut len = 0; while *ptr_utf16.offset(len) != 0 { len += 1; } // Create the utf16 slice and covert it into a string. let name_slice = slice::from_raw_parts(ptr_utf16, len as usize); let name_os_string: OsString = OsStringExt::from_wide(name_slice); let name_string = name_os_string.into_string().unwrap(); // Clean up the property. PropVariantClear(&mut property_value); name_string } } #[inline] fn from_immdevice(device: *mut IMMDevice) -> Self { Device { device: device, future_audio_client: Arc::new(Mutex::new(None)), } } /// Ensures that `future_audio_client` contains a `Some` and returns a locked mutex to it. fn ensure_future_audio_client(&self) -> Result>, IoError> { let mut lock = self.future_audio_client.lock().unwrap(); if lock.is_some() { return Ok(lock); } let audio_client: *mut IAudioClient = unsafe { let mut audio_client = mem::uninitialized(); let hresult = (*self.device).Activate(&IID_IAudioClient, CLSCTX_ALL, ptr::null_mut(), &mut audio_client); // can fail if the device has been disconnected since we enumerated it, or if // the device doesn't support playback for some reason check_result(hresult)?; assert!(!audio_client.is_null()); audio_client as *mut _ }; *lock = Some(IAudioClientWrapper(audio_client)); Ok(lock) } /// Returns an uninitialized `IAudioClient`. #[inline] pub(crate) fn build_audioclient(&self) -> Result<*mut IAudioClient, IoError> { let mut lock = self.ensure_future_audio_client()?; let client = lock.unwrap().0; *lock = None; Ok(client) } // There is no way to query the list of all formats that are supported by the // audio processor, so instead we just trial some commonly supported formats. // // Common formats are trialed by first getting the default format (returned via // `GetMixFormat`) and then mutating that format with common sample rates and // querying them via `IsFormatSupported`. // // When calling `IsFormatSupported` with the shared-mode audio engine, only the default // number of channels seems to be supported. Any more or less returns an invalid // parameter error. Thus we just assume that the default number of channels is the only // number supported. fn supported_formats(&self) -> Result { // initializing COM because we call `CoTaskMemFree` to release the format. com::com_initialized(); // Retrieve the `IAudioClient`. let lock = match self.ensure_future_audio_client() { Err(ref e) if e.raw_os_error() == Some(AUDCLNT_E_DEVICE_INVALIDATED) => return Err(FormatsEnumerationError::DeviceNotAvailable), e => e.unwrap(), }; let client = lock.unwrap().0; unsafe { // Retrieve the pointer to the default WAVEFORMATEX. let mut default_waveformatex_ptr = WaveFormatExPtr(mem::uninitialized()); match check_result((*client).GetMixFormat(&mut default_waveformatex_ptr.0)) { Err(ref e) if e.raw_os_error() == Some(AUDCLNT_E_DEVICE_INVALIDATED) => { return Err(FormatsEnumerationError::DeviceNotAvailable); }, Err(e) => panic!("{:?}", e), Ok(()) => (), }; // If the default format can't succeed we have no hope of finding other formats. assert_eq!(try!(is_format_supported(client, default_waveformatex_ptr.0)), true); // Copy the format to use as a test format (as to avoid mutating the original format). let mut test_format = { match WaveFormat::copy_from_waveformatex_ptr(default_waveformatex_ptr.0) { Some(f) => f, // If the format is neither EX or EXTENSIBLE we don't know how to work with it. None => return Ok(vec![].into_iter()), } }; // Begin testing common sample rates. // // NOTE: We should really be testing for whole ranges here, but it is infeasible to // test every sample rate up to the overflow limit as the `IsFormatSupported` method is // quite slow. let mut supported_sample_rates: Vec = Vec::new(); for &rate in COMMON_SAMPLE_RATES { let rate = rate.0 as DWORD; test_format.nSamplesPerSec = rate; test_format.nAvgBytesPerSec = rate * (*default_waveformatex_ptr.0).nBlockAlign as DWORD; if try!(is_format_supported(client, test_format.as_ptr())) { supported_sample_rates.push(rate); } } // If the common rates don't include the default one, add the default. let default_sr = (*default_waveformatex_ptr.0).nSamplesPerSec as _; if !supported_sample_rates.iter().any(|&r| r == default_sr) { supported_sample_rates.push(default_sr); } // Reset the sample rate on the test format now that we're done. test_format.nSamplesPerSec = (*default_waveformatex_ptr.0).nSamplesPerSec; test_format.nAvgBytesPerSec = (*default_waveformatex_ptr.0).nAvgBytesPerSec; // TODO: Test the different sample formats? // Create the supported formats. let mut format = format_from_waveformatex_ptr(default_waveformatex_ptr.0) .expect("could not create a cpal::Format from a WAVEFORMATEX"); let mut supported_formats = Vec::with_capacity(supported_sample_rates.len()); for rate in supported_sample_rates { format.sample_rate = SampleRate(rate as _); supported_formats.push(SupportedFormat::from(format.clone())); } Ok(supported_formats.into_iter()) } } pub fn supported_input_formats(&self) -> Result { if self.data_flow() == eCapture { self.supported_formats() // If it's an output device, assume no input formats. } else { Ok(vec![].into_iter()) } } pub fn supported_output_formats(&self) -> Result { if self.data_flow() == eRender { self.supported_formats() // If it's an input device, assume no output formats. } else { Ok(vec![].into_iter()) } } // We always create voices in shared mode, therefore all samples go through an audio // processor to mix them together. // // One format is guaranteed to be supported, the one returned by `GetMixFormat`. fn default_format(&self) -> Result { // initializing COM because we call `CoTaskMemFree` com::com_initialized(); let lock = match self.ensure_future_audio_client() { Err(ref e) if e.raw_os_error() == Some(AUDCLNT_E_DEVICE_INVALIDATED) => return Err(DefaultFormatError::DeviceNotAvailable), e => e.unwrap(), }; let client = lock.unwrap().0; unsafe { let mut format_ptr = WaveFormatExPtr(mem::uninitialized()); match check_result((*client).GetMixFormat(&mut format_ptr.0)) { Err(ref e) if e.raw_os_error() == Some(AUDCLNT_E_DEVICE_INVALIDATED) => { return Err(DefaultFormatError::DeviceNotAvailable); }, Err(e) => panic!("{:?}", e), Ok(()) => (), }; format_from_waveformatex_ptr(format_ptr.0) .ok_or(DefaultFormatError::StreamTypeNotSupported) } } fn data_flow(&self) -> EDataFlow { let endpoint = Endpoint::from(self.device as *const _); endpoint.data_flow() } pub fn default_input_format(&self) -> Result { if self.data_flow() == eCapture { self.default_format() } else { Err(DefaultFormatError::StreamTypeNotSupported) } } pub fn default_output_format(&self) -> Result { let data_flow = self.data_flow(); if data_flow == eRender { self.default_format() } else { Err(DefaultFormatError::StreamTypeNotSupported) } } } impl PartialEq for Device { #[inline] fn eq(&self, other: &Device) -> bool { self.device == other.device } } impl Eq for Device { } impl Clone for Device { #[inline] fn clone(&self) -> Device { unsafe { (*self.device).AddRef(); } Device { device: self.device, future_audio_client: self.future_audio_client.clone(), } } } impl Drop for Device { #[inline] fn drop(&mut self) { unsafe { (*self.device).Release(); } if let Some(client) = self.future_audio_client.lock().unwrap().take() { unsafe { (*client.0).Release(); } } } } impl Drop for Endpoint { fn drop(&mut self) { unsafe { (*self.endpoint).Release(); } } } impl From<*const IMMDevice> for Endpoint { fn from(device: *const IMMDevice) -> Self { unsafe { let endpoint = immendpoint_from_immdevice(device); Endpoint { endpoint: endpoint } } } } impl Endpoint { fn data_flow(&self) -> EDataFlow { unsafe { data_flow_from_immendpoint(self.endpoint) } } } lazy_static! { static ref ENUMERATOR: Enumerator = { // COM initialization is thread local, but we only need to have COM initialized in the // thread we create the objects in com::com_initialized(); // building the devices enumerator object unsafe { let mut enumerator: *mut IMMDeviceEnumerator = mem::uninitialized(); let hresult = CoCreateInstance( &CLSID_MMDeviceEnumerator, ptr::null_mut(), CLSCTX_ALL, &IMMDeviceEnumerator::uuidof(), &mut enumerator as *mut *mut IMMDeviceEnumerator as *mut _, ); check_result(hresult).unwrap(); Enumerator(enumerator) } }; } /// RAII object around `IMMDeviceEnumerator`. struct Enumerator(*mut IMMDeviceEnumerator); unsafe impl Send for Enumerator { } unsafe impl Sync for Enumerator { } impl Drop for Enumerator { #[inline] fn drop(&mut self) { unsafe { (*self.0).Release(); } } } /// WASAPI implementation for `Devices`. pub struct Devices { collection: *mut IMMDeviceCollection, total_count: u32, next_item: u32, } unsafe impl Send for Devices { } unsafe impl Sync for Devices { } impl Drop for Devices { #[inline] fn drop(&mut self) { unsafe { (*self.collection).Release(); } } } impl Default for Devices { fn default() -> Devices { unsafe { let mut collection: *mut IMMDeviceCollection = mem::uninitialized(); // can fail because of wrong parameters (should never happen) or out of memory check_result( (*ENUMERATOR.0).EnumAudioEndpoints( eAll, DEVICE_STATE_ACTIVE, &mut collection, ) ).unwrap(); let mut count = mem::uninitialized(); // can fail if the parameter is null, which should never happen check_result((*collection).GetCount(&mut count)).unwrap(); Devices { collection: collection, total_count: count, next_item: 0, } } } } impl Iterator for Devices { type Item = Device; fn next(&mut self) -> Option { if self.next_item >= self.total_count { return None; } unsafe { let mut device = mem::uninitialized(); // can fail if out of range, which we just checked above check_result((*self.collection).Item(self.next_item, &mut device)).unwrap(); self.next_item += 1; Some(Device::from_immdevice(device)) } } #[inline] fn size_hint(&self) -> (usize, Option) { let num = self.total_count - self.next_item; let num = num as usize; (num, Some(num)) } } fn default_device(data_flow: EDataFlow) -> Option { unsafe { let mut device = mem::uninitialized(); let hres = (*ENUMERATOR.0) .GetDefaultAudioEndpoint(data_flow, eConsole, &mut device); if let Err(_err) = check_result(hres) { return None; // TODO: check specifically for `E_NOTFOUND`, and panic otherwise } Some(Device::from_immdevice(device)) } } pub fn default_input_device() -> Option { default_device(eCapture) } pub fn default_output_device() -> Option { default_device(eRender) }