#![no_std]
#![allow(clippy::missing_safety_doc)]
use core::{
cell::UnsafeCell,
marker::PhantomData,
mem::MaybeUninit,
sync::atomic::{AtomicBool, AtomicUsize, Ordering},
};
use maitake::sync::{WaitCell, WaitQueue};
/// Backing storage for an [`MpScQueue`]: a contiguous run of queue cells.
///
/// # Safety
///
/// The queue treats the value returned by [`Storage::buf`] as a raw array:
/// `MpScQueue::new` builds a slice of `len` elements from the pointer and
/// writes every slot, and the enqueue/dequeue paths offset the pointer by
/// positions masked with `len - 1`. Implementors must therefore return a
/// pointer that is valid for reads and writes of `len` consecutive cells.
///
/// NOTE(review): the queue calls `buf()` again on every operation, so it
/// presumably also relies on the same pointer and length being returned for
/// as long as the queue exists — confirm against the implementors.
pub unsafe trait Storage<T> {
/// Returns a pointer to the first cell of the buffer and the number of cells.
fn buf(&self) -> (*const UnsafeCell<Cell<T>>, usize);
}
/// A bounded queue layered over a caller-provided [`Storage`] buffer, with
/// both synchronous (`*_sync`) and asynchronous (`*_async`) operations.
///
/// The name suggests multi-producer / single-consumer use; note that the
/// consumer side waits on a single `WaitCell` while producers wait on a
/// `WaitQueue`.
pub struct MpScQueue<T, STO: Storage<T>> {
/// Provider of the cell buffer; queried via `buf()` on every operation.
storage: STO,
/// Position counter handed to the `dequeue` routine.
dequeue_pos: AtomicUsize,
/// Position counter handed to the `enqueue` routine.
enqueue_pos: AtomicUsize,
/// Woken by `enqueue_sync` after a successful enqueue; awaited by `dequeue_async`.
cons_wait: WaitCell,
/// Woken (all waiters) by `dequeue_sync` after a successful dequeue; awaited by `enqueue_async`.
prod_wait: WaitQueue,
/// Set by `close`; checked by `enqueue_sync` before accepting an item.
closed: AtomicBool,
/// Ties the otherwise unused item type `T` to the queue type.
pd: PhantomData<T>,
}
/// Reasons an enqueue can fail. Both variants hand the rejected item back to
/// the caller.
#[derive(Debug, Eq, PartialEq)]
pub enum EnqueueError<T> {
/// The underlying `enqueue` routine could not place the item (no free slot).
Full(T),
/// The queue's closed flag was set, or waiting for space reported an error.
Closed(T),
}
/// Reasons an asynchronous dequeue can fail.
#[derive(Debug, Eq, PartialEq)]
pub enum DequeueError {
/// Nothing could be dequeued and waiting on the consumer wait cell returned
/// an error (`close` closes that cell).
Closed,
}
impl<T, STO: Storage<T>> MpScQueue<T, STO> {
    /// Builds a queue on top of `storage`, overwriting every cell with an
    /// empty payload and a sequence number equal to its index.
    ///
    /// # Panics
    ///
    /// Panics if the length reported by `storage.buf()` is not a power of
    /// two, or is not larger than 1.
    #[track_caller]
    pub fn new(storage: STO) -> Self {
        let (base, capacity) = storage.buf();
        assert_eq!(
            capacity,
            capacity.next_power_of_two(),
            "Capacity must be a power of two!"
        );
        assert!(capacity > 1, "Capacity must be larger than 1!");

        // SAFETY: relies on the `Storage` implementor handing out a pointer
        // that is valid for `capacity` cells.
        let slots = unsafe { core::slice::from_raw_parts(base, capacity) };
        for (index, slot) in slots.iter().enumerate() {
            let fresh = Cell {
                data: MaybeUninit::uninit(),
                sequence: AtomicUsize::new(index),
            };
            // SAFETY: `slot.get()` points at one of the cells covered above;
            // `write` overwrites it without reading or dropping the old value.
            unsafe { slot.get().write(fresh) };
        }

        Self {
            storage,
            enqueue_pos: AtomicUsize::new(0),
            dequeue_pos: AtomicUsize::new(0),
            prod_wait: WaitQueue::new(),
            cons_wait: WaitCell::new(),
            closed: AtomicBool::new(false),
            pd: PhantomData,
        }
    }

    /// Marks the queue as closed and closes both wait primitives.
    ///
    /// After this, `enqueue_sync` rejects items with [`EnqueueError::Closed`].
    pub fn close(&self) {
        self.closed.store(true, Ordering::Release);
        self.cons_wait.close();
        self.prod_wait.close();
    }

    /// Takes one item out of the queue without waiting.
    ///
    /// Returns `None` when nothing is available. The closed flag is not
    /// consulted here, so items that are already queued can still be taken
    /// after [`close`](Self::close). Every successful dequeue wakes all
    /// waiters on the producer wait queue.
    pub fn dequeue_sync(&self) -> Option<T> {
        let (base, capacity) = self.storage.buf();
        let mask = capacity - 1;
        // SAFETY: relies on the `Storage` contract for `base`/`capacity`.
        let item = unsafe { dequeue((*base).get(), &self.dequeue_pos, mask) }?;
        self.prod_wait.wake_all();
        Some(item)
    }

    /// Puts `item` into the queue without waiting.
    ///
    /// # Errors
    ///
    /// * [`EnqueueError::Closed`] if the closed flag is set.
    /// * [`EnqueueError::Full`] if the item could not be placed.
    ///
    /// In both cases the item is handed back inside the error. A successful
    /// enqueue wakes the consumer wait cell.
    pub fn enqueue_sync(&self, item: T) -> Result<(), EnqueueError<T>> {
        let is_closed = self.closed.load(Ordering::Acquire);
        if is_closed {
            return Err(EnqueueError::Closed(item));
        }

        let (base, capacity) = self.storage.buf();
        let mask = capacity - 1;
        // SAFETY: relies on the `Storage` contract for `base`/`capacity`.
        let outcome = unsafe { enqueue((*base).get(), &self.enqueue_pos, mask, item) };
        match outcome {
            Ok(()) => {
                self.cons_wait.wake();
                Ok(())
            }
            Err(rejected) => Err(EnqueueError::Full(rejected)),
        }
    }

    /// Puts `item` into the queue, waiting on the producer wait queue for as
    /// long as `enqueue_sync` reports the queue as full.
    ///
    /// # Errors
    ///
    /// Returns [`EnqueueError::Closed`] (with the item) if the queue is
    /// closed, either up front or while waiting for space.
    pub async fn enqueue_async(&self, mut item: T) -> Result<(), EnqueueError<T>> {
        loop {
            item = match self.enqueue_sync(item) {
                // Only "full" is retried; take the item back for the next attempt.
                Err(EnqueueError::Full(returned)) => returned,
                // Success or closed: hand the result straight to the caller.
                settled => return settled,
            };

            if self.prod_wait.wait().await.is_err() {
                return Err(EnqueueError::Closed(item));
            }
        }
    }

    /// Takes one item out of the queue, waiting on the consumer wait cell
    /// while nothing is available.
    ///
    /// # Errors
    ///
    /// Returns [`DequeueError::Closed`] if nothing could be dequeued and the
    /// wait reported an error.
    pub async fn dequeue_async(&self) -> Result<T, DequeueError> {
        loop {
            // Subscribe before checking the queue — presumably so that a
            // wake-up arriving between the check and the await is not lost.
            let notified = self.cons_wait.subscribe().await;

            if let Some(item) = self.dequeue_sync() {
                return Ok(item);
            }

            if notified.await.is_err() {
                return Err(DequeueError::Closed);
            }
        }
    }
}
// SAFETY (as far as this file shows): shared state is only touched through
// atomics and the `enqueue`/`dequeue` routines, and items of type `T` are moved
// into and out of the queue through `&self`, hence the `T: Send` requirement.
// NOTE(review): no bound is placed on `STO` itself — confirm that every
// `Storage` implementation is fine to share between threads.
unsafe impl<T: Send, STO: Storage<T>> Sync for MpScQueue<T, STO> {}
impl<T, STO: Storage<T>> Drop for MpScQueue<T, STO> {
    /// Drains and drops every item still queued, then closes both wait
    /// primitives.
    fn drop(&mut self) {
        // Items live in `MaybeUninit` slots, so they are only dropped if they
        // are taken out of the buffer first.
        while let Some(leftover) = self.dequeue_sync() {
            drop(leftover);
        }
        self.cons_wait.close();
        self.prod_wait.close();
    }
}
/// One slot of the queue's buffer: a possibly-uninitialised item plus the
/// sequence number the enqueue/dequeue routines use to decide whether the
/// slot may be written or read.
pub struct Cell<T> {
/// The stored item; written by `enqueue` and read out by `dequeue`.
data: MaybeUninit<T>,
/// Sequence stamp compared against the enqueue/dequeue positions.
sequence: AtomicUsize,
}
/// Creates one empty cell: no item stored and a sequence number of zero.
///
/// Usable in `const` contexts.
pub const fn single_cell<T>() -> Cell<T> {
    let sequence = AtomicUsize::new(0);
    let data = MaybeUninit::uninit();
    Cell { data, sequence }
}
/// Creates an array of `N` empty cells, each a fresh copy of
/// `Cell::SINGLE_CELL` (no item stored, sequence number zero).
pub fn cell_array<const N: usize, T: Sized>() -> [Cell<T>; N] {
    core::array::from_fn(|_| Cell::<T>::SINGLE_CELL)
}
impl<T> Cell<T> {
#[allow(clippy::declare_interior_mutable_const)]
const SINGLE_CELL: Self = Self::new(0);
const fn new(seq: usize) -> Self {
Self {
data: MaybeUninit::uninit(),
sequence: AtomicUsize::new(seq),
}
}
}
/// Tries to take one item out of the cell buffer.
///
/// Returns `Some(item)` when the cell at the current dequeue position has
/// been published (its sequence equals `position + 1`), and `None` when it
/// has not, i.e. there is nothing to take right now.
///
/// On success the cell's sequence is advanced by `mask + 1` so that it
/// becomes writable again once the enqueue position reaches it.
///
/// # Safety
///
/// * `buffer` must point to `mask + 1` consecutive cells that are valid for
///   reads and writes, and `mask + 1` must be a power of two.
/// * The cells' sequence numbers must have been initialised to their index
///   and afterwards only modified by `enqueue`/`dequeue` using the same
///   position counters.
unsafe fn dequeue<T>(buffer: *mut Cell<T>, dequeue_pos: &AtomicUsize, mask: usize) -> Option<T> {
    let mut pos = dequeue_pos.load(Ordering::Relaxed);
    let mut cell;
    loop {
        cell = buffer.add(pos & mask);
        let seq = (*cell).sequence.load(Ordering::Acquire);
        // Signed distance between the cell's stamp and the stamp a readable
        // cell would carry. It must be computed at full pointer width: the
        // distance can be as large as the capacity, so a narrower integer
        // would flip its sign for big buffers.
        let dif = seq.wrapping_sub(pos.wrapping_add(1)) as isize;
        if dif == 0 {
            // The cell is readable; try to claim it by advancing the position.
            match dequeue_pos.compare_exchange_weak(
                pos,
                pos.wrapping_add(1),
                Ordering::Relaxed,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                // Lost the race (or failed spuriously): continue from the
                // value the counter actually holds instead of a stale one.
                Err(current) => pos = current,
            }
        } else if dif < 0 {
            // Not published yet: nothing to dequeue.
            return None;
        } else {
            // Our position is stale; reload and retry.
            pos = dequeue_pos.load(Ordering::Relaxed);
        }
    }
    // The position was claimed above, so this cell is ours to read.
    let data = (*cell).data.as_ptr().read();
    // Release the cell for the producer that will reach it one lap later.
    (*cell)
        .sequence
        .store(pos.wrapping_add(mask).wrapping_add(1), Ordering::Release);
    Some(data)
}
/// Tries to place `item` into the cell buffer.
///
/// Returns `Ok(())` when the cell at the current enqueue position is free
/// (its sequence equals the position), and `Err(item)` — handing the item
/// back — when that cell has not been released by a dequeue yet, i.e. the
/// buffer is full.
///
/// On success the cell's sequence is set to `position + 1`, which publishes
/// the item to `dequeue`.
///
/// # Safety
///
/// * `buffer` must point to `mask + 1` consecutive cells that are valid for
///   reads and writes, and `mask + 1` must be a power of two.
/// * The cells' sequence numbers must have been initialised to their index
///   and afterwards only modified by `enqueue`/`dequeue` using the same
///   position counters.
unsafe fn enqueue<T>(
    buffer: *mut Cell<T>,
    enqueue_pos: &AtomicUsize,
    mask: usize,
    item: T,
) -> Result<(), T> {
    let mut pos = enqueue_pos.load(Ordering::Relaxed);
    let mut cell;
    loop {
        cell = buffer.add(pos & mask);
        let seq = (*cell).sequence.load(Ordering::Acquire);
        // Signed distance between the cell's stamp and our position. It must
        // be computed at full pointer width: for a full buffer the distance
        // is `1 - capacity`, which a narrower integer would wrap to a
        // positive value and make this loop spin forever on large buffers.
        let dif = seq.wrapping_sub(pos) as isize;
        if dif == 0 {
            // The cell is free; try to claim it by advancing the position.
            match enqueue_pos.compare_exchange_weak(
                pos,
                pos.wrapping_add(1),
                Ordering::Relaxed,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                // Lost the race (or failed spuriously): continue from the
                // value the counter actually holds instead of a stale one.
                Err(current) => pos = current,
            }
        } else if dif < 0 {
            // The cell still holds an item from the previous lap: full.
            return Err(item);
        } else {
            // Our position is stale; reload and retry.
            pos = enqueue_pos.load(Ordering::Relaxed);
        }
    }
    // The position was claimed above, so this cell is ours to write.
    (*cell).data.as_mut_ptr().write(item);
    // Publish the item to the consumer.
    (*cell)
        .sequence
        .store(pos.wrapping_add(1), Ordering::Release);
    Ok(())
}