From 47f966bf7523b23ab6ca5ac3fb95c200d6752788 Mon Sep 17 00:00:00 2001 From: Pierre Krieger Date: Tue, 1 Sep 2015 11:23:41 +0200 Subject: [PATCH] Correctly enumerate audio devices (core + wasapi) --- Cargo.toml | 1 + examples/beep.rs | 2 +- src/lib.rs | 47 +++++++- src/wasapi/com.rs | 32 +++++ src/wasapi/enumerate.rs | 126 +++++++++++++++++++ src/wasapi/mod.rs | 259 ++++------------------------------------ src/wasapi/voice.rs | 231 +++++++++++++++++++++++++++++++++++ 7 files changed, 457 insertions(+), 241 deletions(-) create mode 100644 src/wasapi/com.rs create mode 100644 src/wasapi/enumerate.rs create mode 100644 src/wasapi/voice.rs diff --git a/Cargo.toml b/Cargo.toml index 5a0f1e5..c937984 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ keywords = ["audio", "sound"] [dependencies] libc = "*" +lazy_static = "0.1" [target.i686-pc-windows-gnu.dependencies] winapi = "0.2.1" diff --git a/examples/beep.rs b/examples/beep.rs index 149c2c4..764f724 100644 --- a/examples/beep.rs +++ b/examples/beep.rs @@ -1,7 +1,7 @@ extern crate cpal; fn main() { - let mut channel = cpal::Voice::new(); + let mut channel = cpal::Voice::new(&cpal::get_default_endpoint().unwrap()).unwrap(); // Produce a sinusoid of maximum amplitude. let mut data_source = (0u64..).map(|t| t as f32 * 0.03) diff --git a/src/lib.rs b/src/lib.rs index e3e8a90..cab4f22 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,8 +28,12 @@ reaches the end of the data, it will stop playing. You must continuously fill th calling `append_data` repeatedly if you don't want the audio to stop playing. */ +#[macro_use] +extern crate lazy_static; + pub use samples_formats::{SampleFormat, Sample}; +use std::error::Error; use std::ops::{Deref, DerefMut}; mod samples_formats; @@ -57,6 +61,43 @@ pub type ChannelsCount = u16; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct SamplesRate(pub u32); +/// Describes a format. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct Format { + pub channels: ChannelsCount, + pub samples_rate: SamplesRate, + pub data_type: SampleFormat, +} + +/// An iterator for the list of formats that are supported by the backend. +pub struct EndpointsIterator(cpal_impl::EndpointsIterator); + +impl Iterator for EndpointsIterator { + type Item = Endpoint; + + fn next(&mut self) -> Option { + self.0.next().map(Endpoint) + } + + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } +} + +/// Return an iterator to the list of formats that are supported by the system. +pub fn get_endpoints_list() -> EndpointsIterator { + EndpointsIterator(Default::default()) +} + +/// Return the default endpoint, or `None` if no device is available. +pub fn get_default_endpoint() -> Option { + cpal_impl::get_default_endpoint().map(Endpoint) +} + +/// An opaque type that identifies an end point. +#[derive(Clone, PartialEq, Eq)] +pub struct Endpoint(cpal_impl::Endpoint); + /// Represents a buffer that must be filled with audio data. /// /// You should destroy this object as soon as possible. Data is only committed when it @@ -95,9 +136,9 @@ pub struct Voice(cpal_impl::Voice); impl Voice { /// Builds a new channel. - pub fn new() -> Voice { - let channel = cpal_impl::Voice::new(); - Voice(channel) + pub fn new(endpoint: &Endpoint) -> Result> { + let channel = try!(cpal_impl::Voice::new(&endpoint.0)); + Ok(Voice(channel)) } /// Returns the number of channels. diff --git a/src/wasapi/com.rs b/src/wasapi/com.rs new file mode 100644 index 0000000..92a6f31 --- /dev/null +++ b/src/wasapi/com.rs @@ -0,0 +1,32 @@ +//! Handles COM initialization and cleanup. + +use std::ptr; +use super::winapi; +use super::ole32; +use super::check_result; + +thread_local!(static COM_INITIALIZED: ComInitialized = { + unsafe { + // this call can fail if another library initialized COM in single-threaded mode + // handling this situation properly would make the API more annoying, so we just don't care + check_result(ole32::CoInitializeEx(ptr::null_mut(), winapi::COINIT_MULTITHREADED)).unwrap(); + ComInitialized(ptr::null_mut()) + } +}); + +/// RAII object that guards the fact that COM is initialized. +/// +// We store a raw pointer because it's the only way at the moment to remove `Send`/`Sync` from the +// object. +struct ComInitialized(*mut ()); + +impl Drop for ComInitialized { + fn drop(&mut self) { + unsafe { ole32::CoUninitialize() }; + } +} + +/// Ensures that COM is initialized in this thread. +pub fn com_initialized() { + COM_INITIALIZED.with(|_| {}); +} diff --git a/src/wasapi/enumerate.rs b/src/wasapi/enumerate.rs new file mode 100644 index 0000000..fbacbc6 --- /dev/null +++ b/src/wasapi/enumerate.rs @@ -0,0 +1,126 @@ +use super::winapi; +use super::ole32; +use super::com; +use super::Endpoint; +use super::check_result; + +use std::mem; +use std::ptr; + +lazy_static! { + static ref ENUMERATOR: Enumerator = { + // COM initialization is thread local, but we only need to have COM initialized in the + // thread we create the objects in + com::com_initialized(); + + // building the devices enumerator object + unsafe { + let mut enumerator: *mut winapi::IMMDeviceEnumerator = mem::uninitialized(); + + let hresult = ole32::CoCreateInstance(&winapi::CLSID_MMDeviceEnumerator, + ptr::null_mut(), winapi::CLSCTX_ALL, + &winapi::IID_IMMDeviceEnumerator, + &mut enumerator + as *mut *mut winapi::IMMDeviceEnumerator + as *mut _); + + check_result(hresult).unwrap(); + Enumerator(enumerator) + } + }; +} + +/// RAII object around `winapi::IMMDeviceEnumerator`. +struct Enumerator(*mut winapi::IMMDeviceEnumerator); + +unsafe impl Send for Enumerator {} +unsafe impl Sync for Enumerator {} + +impl Drop for Enumerator { + fn drop(&mut self) { + unsafe { + (*self.0).Release(); + } + } +} + +/// WASAPI implementation for `EndpointsIterator`. +pub struct EndpointsIterator { + collection: *mut winapi::IMMDeviceCollection, + total_count: u32, + next_item: u32, +} + +unsafe impl Send for EndpointsIterator {} +unsafe impl Sync for EndpointsIterator {} + +impl Drop for EndpointsIterator { + fn drop(&mut self) { + unsafe { + (*self.collection).Release(); + } + } +} + +impl Default for EndpointsIterator { + fn default() -> EndpointsIterator { + unsafe { + let mut collection: *mut winapi::IMMDeviceCollection = mem::uninitialized(); + // can fail because of wrong parameters (should never happen) or out of memory + check_result((*ENUMERATOR.0).EnumAudioEndpoints(winapi::EDataFlow::eRender, + winapi::DEVICE_STATE_ACTIVE, + &mut collection)) + .unwrap(); + + let mut count = mem::uninitialized(); + // can fail if the parameter is null, which should never happen + check_result((*collection).GetCount(&mut count)).unwrap(); + + EndpointsIterator { + collection: collection, + total_count: count, + next_item: 0, + } + } + } +} + +impl Iterator for EndpointsIterator { + type Item = Endpoint; + + fn next(&mut self) -> Option { + if self.next_item >= self.total_count { + return None; + } + + unsafe { + let mut device = mem::uninitialized(); + // can fail if out of range, which we just checked above + check_result((*self.collection).Item(self.next_item, &mut device)).unwrap(); + + self.next_item += 1; + Some(Endpoint(device)) + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let num = self.total_count - self.next_item; + let num = num as usize; + (num, Some(num)) + } +} + +pub fn get_default_endpoint() -> Option { + unsafe { + let mut device = mem::uninitialized(); + let hres = (*ENUMERATOR.0).GetDefaultAudioEndpoint(winapi::EDataFlow::eRender, + winapi::ERole::eConsole, &mut device); + + if let Err(_err) = check_result(hres) { + return None; // TODO: check specifically for `E_NOTFOUND`, and panic otherwise + } + + Some(Endpoint(device)) + } +} diff --git a/src/wasapi/mod.rs b/src/wasapi/mod.rs index 28997d6..9c2917d 100644 --- a/src/wasapi/mod.rs +++ b/src/wasapi/mod.rs @@ -2,255 +2,40 @@ extern crate libc; extern crate winapi; extern crate ole32; -use std::{cmp, slice, mem, ptr}; -use std::marker::PhantomData; +use std::io::Error as IoError; -// TODO: determine if should be NoSend or not -pub struct Voice { - audio_client: *mut winapi::IAudioClient, - render_client: *mut winapi::IAudioRenderClient, - max_frames_in_buffer: winapi::UINT32, - num_channels: winapi::WORD, - bytes_per_frame: winapi::WORD, - samples_per_second: winapi::DWORD, - bits_per_sample: winapi::WORD, - playing: bool, -} +pub use self::enumerate::{EndpointsIterator, get_default_endpoint}; +pub use self::voice::{Voice, Buffer}; -pub struct Buffer<'a, T: 'a> { - render_client: *mut winapi::IAudioRenderClient, - buffer_data: *mut T, - buffer_len: usize, - frames: winapi::UINT32, - marker: PhantomData<&'a mut T>, -} +mod com; +mod enumerate; +mod voice; -impl Voice { - pub fn new() -> Voice { - init().unwrap() - } +/// An opaque type that identifies an end point. +#[derive(PartialEq, Eq)] +#[allow(raw_pointer_derive)] +pub struct Endpoint(*mut winapi::IMMDevice); - pub fn get_channels(&self) -> ::ChannelsCount { - self.num_channels as ::ChannelsCount - } +unsafe impl Send for Endpoint {} +unsafe impl Sync for Endpoint {} - pub fn get_samples_rate(&self) -> ::SamplesRate { - ::SamplesRate(self.samples_per_second as u32) - } - - pub fn get_samples_format(&self) -> ::SampleFormat { - match self.bits_per_sample { - 16 => ::SampleFormat::I16, - _ => panic!("{}-bit format not yet supported", self.bits_per_sample), - } - } - - pub fn append_data<'a, T>(&'a mut self, max_elements: usize) -> Buffer<'a, T> { - unsafe { - loop { - // - let frames_available = { - let mut padding = mem::uninitialized(); - let hresult = (*self.audio_client).GetCurrentPadding(&mut padding); - check_result(hresult).unwrap(); - self.max_frames_in_buffer - padding - }; - - if frames_available == 0 { - // TODO: - ::std::thread::sleep_ms(1); - continue; - } - - let frames_available = cmp::min(frames_available, - max_elements as u32 * mem::size_of::() as u32 / - self.bytes_per_frame as u32); - assert!(frames_available != 0); - - // loading buffer - let (buffer_data, buffer_len) = { - let mut buffer: *mut winapi::BYTE = mem::uninitialized(); - let hresult = (*self.render_client).GetBuffer(frames_available, - &mut buffer as *mut *mut libc::c_uchar); - check_result(hresult).unwrap(); - assert!(!buffer.is_null()); - - (buffer as *mut T, - frames_available as usize * self.bytes_per_frame as usize - / mem::size_of::()) - }; - - let buffer = Buffer { - render_client: self.render_client, - buffer_data: buffer_data, - buffer_len: buffer_len, - frames: frames_available, - marker: PhantomData, - }; - - return buffer; - } - } - } - - pub fn play(&mut self) { - if !self.playing { - unsafe { - let hresult = (*self.audio_client).Start(); - check_result(hresult).unwrap(); - } - } - - self.playing = true; - } - - pub fn pause(&mut self) { - if self.playing { - unsafe { - let hresult = (*self.audio_client).Stop(); - check_result(hresult).unwrap(); - } - } - - self.playing = false; +impl Clone for Endpoint { + fn clone(&self) -> Endpoint { + unsafe { (*self.0).AddRef(); } + Endpoint(self.0) } } -unsafe impl Send for Voice {} -unsafe impl Sync for Voice {} - -impl Drop for Voice { +impl Drop for Endpoint { fn drop(&mut self) { - unsafe { - (*self.render_client).Release(); - (*self.audio_client).Release(); - } + unsafe { (*self.0).Release(); } } } -impl<'a, T> Buffer<'a, T> { - pub fn get_buffer<'b>(&'b mut self) -> &'b mut [T] { - unsafe { - slice::from_raw_parts_mut(self.buffer_data, self.buffer_len) - } - } - - pub fn finish(self) { - // releasing buffer - unsafe { - let hresult = (*self.render_client).ReleaseBuffer(self.frames as u32, 0); - check_result(hresult).unwrap(); - }; - } -} - -fn init() -> Result { - // FIXME: release everything - unsafe { - try!(check_result(ole32::CoInitializeEx(ptr::null_mut(), 0))); - - // building the devices enumerator object - let enumerator = { - let mut enumerator: *mut winapi::IMMDeviceEnumerator = mem::uninitialized(); - - let hresult = ole32::CoCreateInstance(&winapi::CLSID_MMDeviceEnumerator, - ptr::null_mut(), winapi::CLSCTX_ALL, - &winapi::IID_IMMDeviceEnumerator, - mem::transmute(&mut enumerator)); - - try!(check_result(hresult)); - &mut *enumerator - }; - - // getting the default end-point - let device = { - let mut device: *mut winapi::IMMDevice = mem::uninitialized(); - let hresult = enumerator.GetDefaultAudioEndpoint(winapi::EDataFlow::eRender, winapi::ERole::eConsole, - mem::transmute(&mut device)); - try!(check_result(hresult)); - &mut *device - }; - - // activating in order to get a `IAudioClient` - let audio_client: &mut winapi::IAudioClient = { - let mut audio_client: *mut winapi::IAudioClient = mem::uninitialized(); - let hresult = device.Activate(&winapi::IID_IAudioClient, winapi::CLSCTX_ALL, - ptr::null_mut(), mem::transmute(&mut audio_client)); - try!(check_result(hresult)); - &mut *audio_client - }; - - // computing the format and initializing the device - let format = { - let format_attempt = winapi::WAVEFORMATEX { - wFormatTag: 1, // WAVE_FORMAT_PCM ; TODO: replace by constant - nChannels: 2, - nSamplesPerSec: 44100, - nAvgBytesPerSec: 2 * 44100 * 2, - nBlockAlign: (2 * 16) / 8, - wBitsPerSample: 16, - cbSize: 0, - }; - - let mut format_ptr: *mut winapi::WAVEFORMATEX = mem::uninitialized(); - let hresult = audio_client.IsFormatSupported(winapi::AUDCLNT_SHAREMODE::AUDCLNT_SHAREMODE_SHARED, - &format_attempt, &mut format_ptr); - try!(check_result(hresult)); - - let format = if format_ptr.is_null() { - &format_attempt - } else { - &*format_ptr - }; - - let format_copy = ptr::read(format); - - let hresult = audio_client.Initialize(winapi::AUDCLNT_SHAREMODE::AUDCLNT_SHAREMODE_SHARED, - 0, 10000000, 0, format, ptr::null()); - - if !format_ptr.is_null() { - ole32::CoTaskMemFree(format_ptr as *mut _); - } - - try!(check_result(hresult)); - - format_copy - }; - - // - let max_frames_in_buffer = { - let mut max_frames_in_buffer = mem::uninitialized(); - let hresult = audio_client.GetBufferSize(&mut max_frames_in_buffer); - try!(check_result(hresult)); - max_frames_in_buffer - }; - - // - let render_client = { - let mut render_client: *mut winapi::IAudioRenderClient = mem::uninitialized(); - let hresult = audio_client.GetService(&winapi::IID_IAudioRenderClient, - mem::transmute(&mut render_client)); - try!(check_result(hresult)); - &mut *render_client - }; - - Ok(Voice { - audio_client: audio_client, - render_client: render_client, - max_frames_in_buffer: max_frames_in_buffer, - num_channels: format.nChannels, - bytes_per_frame: format.nBlockAlign, - samples_per_second: format.nSamplesPerSec, - bits_per_sample: format.wBitsPerSample, - playing: false, - }) - } -} - -fn check_result(result: winapi::HRESULT) -> Result<(), String> { +fn check_result(result: winapi::HRESULT) -> Result<(), IoError> { if result < 0 { - return Err(format!("Error in winapi call")); // TODO: + Err(IoError::from_raw_os_error(result)) + } else { + Ok(()) } - - Ok(()) } diff --git a/src/wasapi/voice.rs b/src/wasapi/voice.rs new file mode 100644 index 0000000..797fb72 --- /dev/null +++ b/src/wasapi/voice.rs @@ -0,0 +1,231 @@ +use super::com; +use super::ole32; +use super::winapi; +use super::Endpoint; +use super::check_result; + +use std::io::Error as IoError; +use std::cmp; +use std::slice; +use std::mem; +use std::ptr; +use std::marker::PhantomData; + +pub struct Voice { + audio_client: *mut winapi::IAudioClient, + render_client: *mut winapi::IAudioRenderClient, + max_frames_in_buffer: winapi::UINT32, + num_channels: winapi::WORD, + bytes_per_frame: winapi::WORD, + samples_per_second: winapi::DWORD, + bits_per_sample: winapi::WORD, + playing: bool, +} + +unsafe impl Send for Voice {} +unsafe impl Sync for Voice {} + +impl Voice { + pub fn new(end_point: &Endpoint) -> Result { + // FIXME: release everything + unsafe { + // making sure that COM is initialized + // it's not actually sure that this is required, but when in doubt do it + com::com_initialized(); + + // activating the end point in order to get a `IAudioClient` + let audio_client: *mut winapi::IAudioClient = { + let mut audio_client = mem::uninitialized(); + let hresult = (*end_point.0).Activate(&winapi::IID_IAudioClient, winapi::CLSCTX_ALL, + ptr::null_mut(), &mut audio_client); + // can fail if the device has been disconnected since we enumerated it, or if + // the device doesn't support playback for some reason + try!(check_result(hresult)); + audio_client as *mut _ + }; + + // computing the format and initializing the device + let format = { + let format_attempt = winapi::WAVEFORMATEX { + wFormatTag: 1, // WAVE_FORMAT_PCM ; TODO: replace by constant + nChannels: 2, + nSamplesPerSec: 44100, + nAvgBytesPerSec: 2 * 44100 * 2, + nBlockAlign: (2 * 16) / 8, + wBitsPerSample: 16, + cbSize: 0, + }; + + let mut format_ptr: *mut winapi::WAVEFORMATEX = mem::uninitialized(); + let hresult = (*audio_client).IsFormatSupported(winapi::AUDCLNT_SHAREMODE::AUDCLNT_SHAREMODE_SHARED, + &format_attempt, &mut format_ptr); + try!(check_result(hresult)); + + let format = if format_ptr.is_null() { + &format_attempt + } else { + &*format_ptr + }; + + let format_copy = ptr::read(format); + + let hresult = (*audio_client).Initialize(winapi::AUDCLNT_SHAREMODE::AUDCLNT_SHAREMODE_SHARED, + 0, 10000000, 0, format, ptr::null()); + + if !format_ptr.is_null() { + ole32::CoTaskMemFree(format_ptr as *mut _); + } + + try!(check_result(hresult)); + + format_copy + }; + + // + let max_frames_in_buffer = { + let mut max_frames_in_buffer = mem::uninitialized(); + let hresult = (*audio_client).GetBufferSize(&mut max_frames_in_buffer); + try!(check_result(hresult)); + max_frames_in_buffer + }; + + // + let render_client = { + let mut render_client: *mut winapi::IAudioRenderClient = mem::uninitialized(); + let hresult = (*audio_client).GetService(&winapi::IID_IAudioRenderClient, + mem::transmute(&mut render_client)); + try!(check_result(hresult)); + &mut *render_client + }; + + Ok(Voice { + audio_client: audio_client, + render_client: render_client, + max_frames_in_buffer: max_frames_in_buffer, + num_channels: format.nChannels, + bytes_per_frame: format.nBlockAlign, + samples_per_second: format.nSamplesPerSec, + bits_per_sample: format.wBitsPerSample, + playing: false, + }) + } + } + + pub fn get_channels(&self) -> ::ChannelsCount { + self.num_channels as ::ChannelsCount + } + + pub fn get_samples_rate(&self) -> ::SamplesRate { + ::SamplesRate(self.samples_per_second as u32) + } + + pub fn get_samples_format(&self) -> ::SampleFormat { + match self.bits_per_sample { + 16 => ::SampleFormat::I16, + _ => panic!("{}-bit format not yet supported", self.bits_per_sample), + } + } + + pub fn append_data<'a, T>(&'a mut self, max_elements: usize) -> Buffer<'a, T> { + unsafe { + loop { + // + let frames_available = { + let mut padding = mem::uninitialized(); + let hresult = (*self.audio_client).GetCurrentPadding(&mut padding); + check_result(hresult).unwrap(); + self.max_frames_in_buffer - padding + }; + + if frames_available == 0 { + // TODO: + ::std::thread::sleep_ms(1); + continue; + } + + let frames_available = cmp::min(frames_available, + max_elements as u32 * mem::size_of::() as u32 / + self.bytes_per_frame as u32); + assert!(frames_available != 0); + + // loading buffer + let (buffer_data, buffer_len) = { + let mut buffer: *mut winapi::BYTE = mem::uninitialized(); + let hresult = (*self.render_client).GetBuffer(frames_available, + &mut buffer as *mut *mut _); + check_result(hresult).unwrap(); + assert!(!buffer.is_null()); + + (buffer as *mut T, + frames_available as usize * self.bytes_per_frame as usize + / mem::size_of::()) + }; + + let buffer = Buffer { + render_client: self.render_client, + buffer_data: buffer_data, + buffer_len: buffer_len, + frames: frames_available, + marker: PhantomData, + }; + + return buffer; + } + } + } + + pub fn play(&mut self) { + if !self.playing { + unsafe { + let hresult = (*self.audio_client).Start(); + check_result(hresult).unwrap(); + } + } + + self.playing = true; + } + + pub fn pause(&mut self) { + if self.playing { + unsafe { + let hresult = (*self.audio_client).Stop(); + check_result(hresult).unwrap(); + } + } + + self.playing = false; + } +} + +impl Drop for Voice { + fn drop(&mut self) { + unsafe { + (*self.render_client).Release(); + (*self.audio_client).Release(); + } + } +} + +pub struct Buffer<'a, T: 'a> { + render_client: *mut winapi::IAudioRenderClient, + buffer_data: *mut T, + buffer_len: usize, + frames: winapi::UINT32, + marker: PhantomData<&'a mut T>, +} + +impl<'a, T> Buffer<'a, T> { + pub fn get_buffer<'b>(&'b mut self) -> &'b mut [T] { + unsafe { + slice::from_raw_parts_mut(self.buffer_data, self.buffer_len) + } + } + + pub fn finish(self) { + // releasing buffer + unsafe { + let hresult = (*self.render_client).ReleaseBuffer(self.frames as u32, 0); + check_result(hresult).unwrap(); + }; + } +}