use crate::comms::bbq::GrantR;
use crate::{
comms::{bbq, oneshot::Reusable},
registry::{self, Envelope, KernelHandle, Message, RegisteredDriver},
services::simple_serial::{SimpleSerialClient, SimpleSerialService},
Kernel,
};
use maitake::sync::Mutex;
use mnemos_alloc::containers::{Arc, FixedVec};
use serde::{Deserialize, Serialize};
use sermux_proto::PortChunk;
use tracing::{self, debug, warn, Level};
use uuid::Uuid;
pub use sermux_proto::WellKnown;
/// Registry driver marker for the serial multiplexer service.
///
/// Clients send [`Request`]s and get back a [`Response`] or a
/// [`SerialMuxError`]. Connecting takes no hello payload (`Hello = ()`), and
/// the driver-specific connect error type is `Infallible`.
pub struct SerialMuxService;
impl RegisteredDriver for SerialMuxService {
type Request = Request;
type Response = Response;
type Error = SerialMuxError;
type Hello = ();
type ConnectError = core::convert::Infallible;
// Key under which this service is bound in the kernel registry.
const UUID: Uuid = crate::registry::known_uuids::kernel::SERIAL_MUX;
}
/// Requests handled by the serial mux service.
pub enum Request {
/// Register port `port_id`. `capacity` is passed to
/// `bbq::new_spsc_channel` to size that port's incoming-data channel.
RegisterPort { port_id: u16, capacity: usize },
}
/// Successful responses from the serial mux service.
pub enum Response {
/// The port was registered; carries the handle used to read from and
/// write to it.
PortRegistered(PortHandle),
}
/// Errors returned by the serial mux service.
#[derive(Debug, Eq, PartialEq)]
pub enum SerialMuxError {
/// A port with the requested id is already registered.
DuplicateItem,
/// The port table has no free slot (its size is `SerialMuxSettings::max_ports`).
RegistryFull,
}
/// Handle to one registered mux port.
///
/// Incoming data for the port is read through [`PortHandle::consumer`];
/// outgoing data is sent with [`PortHandle::send`].
pub struct PortHandle {
// Port id this handle was registered under.
port: u16,
// Consumer half of this port's incoming-data channel.
cons: bbq::Consumer,
// Producer to the serial port, cloned from the one shared by all ports.
outgoing: bbq::MpscProducer,
// Copied from the mux settings; `send` derives its chunk size from it.
max_frame: usize,
}
/// Client used to send requests to the [`SerialMuxService`].
pub struct SerialMuxClient {
// Connection to the service in the kernel registry.
prod: KernelHandle<SerialMuxService>,
// Reply slot reused across requests.
reply: Reusable<Envelope<Result<Response, SerialMuxError>>>,
}
impl SerialMuxClient {
pub async fn from_registry(
kernel: &'static Kernel,
) -> Result<Self, registry::ConnectError<SerialMuxService>> {
let prod = kernel.registry().connect::<SerialMuxService>(()).await?;
Ok(SerialMuxClient {
prod,
reply: Reusable::new_async().await,
})
}
pub async fn from_registry_no_retry(
kernel: &'static Kernel,
) -> Result<Self, registry::ConnectError<SerialMuxService>> {
let prod = kernel
.registry()
.try_connect::<SerialMuxService>(())
.await?;
Ok(SerialMuxClient {
prod,
reply: Reusable::new_async().await,
})
}
pub async fn open_port(&mut self, port_id: u16, capacity: usize) -> Option<PortHandle> {
let resp = self
.prod
.request_oneshot(Request::RegisterPort { port_id, capacity }, &self.reply)
.await
.ok()?;
let body = resp.body.ok()?;
let Response::PortRegistered(port) = body;
Some(port)
}
}
impl PortHandle {
    /// Connects to the [`SerialMuxService`] and opens port `port_id`.
    ///
    /// Returns `None` if connecting to the service fails or the service
    /// rejects the registration (e.g. the id is already in use or the port
    /// table is full).
    pub async fn open(kernel: &'static Kernel, port_id: u16, capacity: usize) -> Option<Self> {
        let mut client = SerialMuxClient::from_registry(kernel).await.ok()?;
        client.open_port(port_id, capacity).await
    }

    /// The port id this handle was registered with.
    pub fn port(&self) -> u16 {
        self.port
    }

    /// The consumer half of this port's incoming-data channel.
    pub fn consumer(&self) -> &bbq::Consumer {
        &self.cons
    }

    /// Sends `data` out on this port.
    ///
    /// `data` is split into chunks of at most `max_frame / 2` bytes (but at
    /// least one byte). Each chunk is encoded as a [`PortChunk`] tagged with
    /// this port's id and committed to the shared outgoing producer. An empty
    /// `data` sends nothing.
    ///
    /// # Panics
    ///
    /// Panics if `PortChunk::encode_to` fails for a grant sized by
    /// `buffer_required()`; the original code treats this as unreachable.
    pub async fn send(&self, data: &[u8]) {
        // `slice::chunks(0)` panics, so a tiny `max_frame` (< 2) must not
        // yield a zero chunk size.
        // NOTE(review): halving `max_frame` presumably leaves headroom for the
        // port header and COBS overhead so each encoded frame fits the
        // receiver's accumulator — confirm.
        let msg_chunk = (self.max_frame / 2).max(1);
        for chunk in data.chunks(msg_chunk) {
            let pc = PortChunk::new(self.port, chunk);
            let needed = pc.buffer_required();
            let mut wgr = self.outgoing.send_grant_exact(needed).await;
            let used = pc
                .encode_to(&mut wgr)
                .expect("sermux encoding should not fail")
                .len();
            wgr.commit(used);
        }
    }
}
/// Entry point for registering and starting the serial mux service.
pub struct SerialMuxServer;
/// Configuration for the serial mux.
///
/// NOTE(review): when deserialized, an omitted `enabled` becomes `false`
/// (`#[serde(default)]`), but `SerialMuxSettings::default()` sets it to
/// `true`. Confirm which default is intended.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct SerialMuxSettings {
/// Whether the mux should be started. `SerialMuxServer::register*` does not
/// read this field itself.
#[serde(default)]
pub enabled: bool,
/// Maximum number of ports that may be registered; also used as the
/// request-queue capacity when binding the service.
#[serde(default = "SerialMuxSettings::default_max_ports")]
pub max_ports: u16,
/// Size in bytes of the incoming frame accumulator; `PortHandle::send`
/// also derives its outgoing chunk size from it.
#[serde(default = "SerialMuxSettings::default_max_frame")]
pub max_frame: usize,
}
impl SerialMuxServer {
    /// Registers and starts the serial mux, connecting to the
    /// [`SimpleSerialService`] with `SimpleSerialClient::from_registry`.
    ///
    /// # Errors
    ///
    /// Returns a [`RegistrationError`] if connecting to the simple serial
    /// service fails, it hands out no serial port, or binding the mux
    /// service in the registry fails.
    #[tracing::instrument(
        name = "SerialMuxServer::register",
        level = Level::INFO,
        skip(kernel, settings),
        err(Debug),
    )]
    pub async fn register(
        kernel: &'static Kernel,
        settings: SerialMuxSettings,
    ) -> Result<(), RegistrationError> {
        let serial_handle = SimpleSerialClient::from_registry(kernel)
            .await
            .map_err(RegistrationError::Connect)?;
        Self::register_inner(kernel, settings, serial_handle).await
    }

    /// Like [`SerialMuxServer::register`], but connects to the
    /// [`SimpleSerialService`] with `SimpleSerialClient::from_registry_no_retry`.
    ///
    /// # Errors
    ///
    /// Same conditions as [`SerialMuxServer::register`].
    #[tracing::instrument(
        name = "SerialMuxServer::register_no_retry",
        level = Level::INFO,
        // Skip `settings` as `register` does; `register_inner` logs them once.
        skip(kernel, settings),
        err(Debug),
    )]
    pub async fn register_no_retry(
        kernel: &'static Kernel,
        settings: SerialMuxSettings,
    ) -> Result<(), RegistrationError> {
        let serial_handle = SimpleSerialClient::from_registry_no_retry(kernel)
            .await
            .map_err(RegistrationError::Connect)?;
        Self::register_inner(kernel, settings, serial_handle).await
    }

    /// Shared body of `register` and `register_no_retry`.
    ///
    /// Takes the serial port from `serial_handle`, binds [`SerialMuxService`]
    /// in the registry, then spawns a [`CommanderTask`] (serves port
    /// registration requests) and an [`IncomingMuxerTask`] (routes incoming
    /// frames to registered ports). Both tasks share one mutex-protected
    /// [`MuxingInfo`].
    ///
    /// NOTE(review): `settings.enabled` is not consulted here; presumably
    /// the caller decides whether to register at all — confirm.
    async fn register_inner(
        kernel: &'static Kernel,
        settings: SerialMuxSettings,
        mut serial_handle: SimpleSerialClient,
    ) -> Result<(), RegistrationError> {
        tracing::info!(?settings, "Starting SerialMuxServer");
        let SerialMuxSettings {
            max_ports,
            max_frame,
            ..
        } = settings;
        let max_ports = usize::from(max_ports);
        let serial_port = serial_handle
            .get_port()
            .await
            .ok_or(RegistrationError::NoSerialPortAvailable)?;
        let (sprod, scons) = serial_port.split();
        let sprod = sprod.into_mpmc_producer().await;
        let ports = FixedVec::new(max_ports).await;
        let imutex = Arc::new(Mutex::new(MuxingInfo { ports, max_frame })).await;
        let listener = kernel
            .registry()
            .bind_konly::<SerialMuxService>(max_ports)
            .await
            .map_err(|_| RegistrationError::MuxAlreadyRegistered)?;
        // Accumulator for incoming frames; frames longer than this are dropped.
        let buf = FixedVec::new(max_frame).await;
        let commander = CommanderTask {
            cmd: listener.into_request_stream(max_ports).await,
            out: sprod,
            mux: imutex.clone(),
        };
        let muxer = IncomingMuxerTask {
            incoming: scons,
            mux: imutex,
            buf,
        };
        kernel.spawn(commander.run()).await;
        kernel.spawn(muxer.run()).await;
        Ok(())
    }
}
impl SerialMuxSettings {
    /// Port-table size used when `max_ports` is not configured.
    pub const DEFAULT_MAX_PORTS: u16 = 16;
    /// Frame accumulator size in bytes used when `max_frame` is not configured.
    pub const DEFAULT_MAX_FRAME: usize = 512;

    // Referenced by `#[serde(default = "...")]` on the struct's fields.
    const fn default_max_ports() -> u16 {
        Self::DEFAULT_MAX_PORTS
    }

    // Referenced by `#[serde(default = "...")]` on the struct's fields.
    const fn default_max_frame() -> usize {
        Self::DEFAULT_MAX_FRAME
    }

    /// Returns these settings with `max_ports` replaced.
    pub fn with_max_ports(mut self, max_ports: u16) -> Self {
        self.max_ports = max_ports;
        self
    }

    /// Returns these settings with `max_frame` replaced.
    pub fn with_max_frame(mut self, max_frame: usize) -> Self {
        self.max_frame = max_frame;
        self
    }
}
impl Default for SerialMuxSettings {
    /// Enabled, with the default port count and frame size.
    fn default() -> Self {
        // NOTE(review): this enables the mux, whereas deserializing a config
        // that omits `enabled` yields `false` via `#[serde(default)]`.
        // Confirm which default is intended.
        Self {
            enabled: true,
            max_ports: Self::default_max_ports(),
            max_frame: Self::default_max_frame(),
        }
    }
}
/// Reasons `SerialMuxServer::register*` can fail.
#[derive(Debug, Eq, PartialEq)]
pub enum RegistrationError {
/// Connecting to the simple serial service failed.
Connect(registry::ConnectError<SimpleSerialService>),
/// The simple serial service returned no port from `get_port`.
NoSerialPortAvailable,
/// Binding the mux service in the registry failed (any bind error maps
/// here; presumably because it is already registered).
MuxAlreadyRegistered,
}
/// Port-table entry for one registered port.
struct PortInfo {
// Port id that incoming frames are matched against.
port: u16,
// Producer half of the port's incoming-data channel; decoded payloads go here.
upstream: bbq::SpscProducer,
}
/// Mux state shared, behind a mutex, by the commander and muxer tasks.
struct MuxingInfo {
// Registered ports; its capacity is `max_ports`.
ports: FixedVec<PortInfo>,
// Copied into each new `PortHandle`.
max_frame: usize,
}
/// Task that serves `SerialMuxService` requests.
struct CommanderTask {
// Incoming requests from clients.
cmd: registry::listener::RequestStream<SerialMuxService>,
// Producer for outgoing serial data; cloned into each `PortHandle`.
out: bbq::MpscProducer,
mux: Arc<Mutex<MuxingInfo>>,
}
/// Task that reads framed serial input and routes payloads to ports.
struct IncomingMuxerTask {
// Accumulates the bytes of the current frame; its capacity is `max_frame`.
buf: FixedVec<u8>,
// Consumer half of the serial port.
incoming: bbq::Consumer,
mux: Arc<Mutex<MuxingInfo>>,
}
impl MuxingInfo {
    /// Registers a new port with id `port_id`.
    ///
    /// Creates an SPSC channel of `capacity` for the port's incoming data,
    /// keeps the producer half in the port table, and returns a
    /// [`PortHandle`] holding the consumer half plus a clone of `outgoing`.
    ///
    /// # Errors
    ///
    /// - [`SerialMuxError::RegistryFull`] if the table has no free slot. This
    ///   is checked first, so it takes precedence over a duplicate id.
    /// - [`SerialMuxError::DuplicateItem`] if `port_id` is already registered.
    async fn register_port(
        &mut self,
        port_id: u16,
        capacity: usize,
        outgoing: &bbq::MpscProducer,
    ) -> Result<PortHandle, SerialMuxError> {
        if self.ports.is_full() {
            return Err(SerialMuxError::RegistryFull);
        }

        let id_in_use = self
            .ports
            .as_slice()
            .iter()
            .any(|info| info.port == port_id);
        if id_in_use {
            return Err(SerialMuxError::DuplicateItem);
        }

        let (upstream, cons) = bbq::new_spsc_channel(capacity).await;
        let entry = PortInfo {
            port: port_id,
            upstream,
        };
        // Should not fail after the `is_full` check (we hold `&mut self`),
        // but report it as a full table rather than panicking.
        if self.ports.try_push(entry).is_err() {
            return Err(SerialMuxError::RegistryFull);
        }

        Ok(PortHandle {
            port: port_id,
            cons,
            outgoing: outgoing.clone(),
            max_frame: self.max_frame,
        })
    }
}
impl CommanderTask {
async fn run(self) {
loop {
let Message { msg: req, reply } = self.cmd.next_request().await;
match req.body {
Request::RegisterPort { port_id, capacity } => {
let res = {
let mut mux = self.mux.lock().await;
mux.register_port(port_id, capacity, &self.out).await
}
.map(Response::PortRegistered);
let resp = req.reply_with(res);
reply.reply_konly(resp).await.map_err(drop).unwrap();
}
}
}
}
}
impl IncomingMuxerTask {
async fn run(mut self) {
loop {
let rgr = self.incoming.read_grant().await;
if !take_from_grant(&mut self.buf, rgr) {
continue;
}
let (port_id, datab) = match try_decode(self.buf.as_slice_mut()) {
Some(a) => a,
None => {
self.buf.clear();
continue;
}
};
let mux = self.mux.lock().await;
if let Some(port) = mux.ports.as_slice().iter().find(|p| p.port == port_id) {
if let Some(mut wgr) = port.upstream.send_grant_exact_sync(datab.len()) {
wgr.copy_from_slice(datab);
wgr.commit(datab.len());
debug!(port_id, len = datab.len(), "Sent bytes to port");
} else {
warn!(port_id, len = datab.len(), "Discarded bytes, full buffer");
}
} else {
warn!(port_id, len = datab.len(), "Discarded bytes, no consumer");
}
self.buf.clear();
}
}
}
/// Moves bytes from `grant` into the frame accumulator `buffer`.
///
/// Copying stops after the first `0x00` delimiter if one is present; exactly
/// the bytes considered (up to and including that delimiter, or the whole
/// grant) are released back to the channel.
///
/// Returns `true` when a complete delimited frame is now in `buffer` and
/// should be decoded. If the bytes don't fit, a warning is logged, the
/// accumulator is cleared, and `false` is returned.
fn take_from_grant(buffer: &mut FixedVec<u8>, grant: GrantR) -> bool {
    let delimiter = grant.iter().position(|&byte| byte == 0);
    let to_use: &[u8] = match delimiter {
        Some(idx) => &grant[..=idx],
        None => &grant[..],
    };

    let accepted = buffer.try_extend_from_slice(to_use).is_ok();
    if !accepted {
        warn!("Overfilled accumulator");
        buffer.clear();
    }

    let used = to_use.len();
    grant.release(used);
    debug!(used, "consumed incoming bytes");

    accepted && delimiter.is_some()
}
/// COBS-decodes `buffer` in place and splits the decoded bytes into a
/// little-endian `u16` port id (the first two bytes) and the payload.
///
/// Returns `None`, after logging a warning, if decoding fails or yields fewer
/// than 3 bytes. A valid frame therefore carries at least one payload byte.
fn try_decode(buffer: &mut [u8]) -> Option<(u16, &[u8])> {
    let used = match cobs::decode_in_place(buffer) {
        Err(_) => {
            warn!("Cobs decode failed!");
            return None;
        }
        Ok(n) if n >= 3 => n,
        Ok(_) => {
            warn!("Cobs decode too short!");
            return None;
        }
    };
    let (header, payload) = buffer.get(..used)?.split_at(2);
    let port_id = u16::from_le_bytes([header[0], header[1]]);
    Some((port_id, payload))
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::comms::bbq::{Consumer, SpscProducer};
    use core::ops::Deref;

    /// COBS encoding of port 0 (`[0x00, 0x00]`) followed by the payload `b"!"`,
    /// terminated by the `0x00` frame delimiter.
    const BANG_FRAME: &[u8] = &[0x01, 0x01, 0x02, b'!', 0x00];

    /// Fixture: a 128-byte SPSC channel standing in for the serial port, plus
    /// a 64-byte frame accumulator for `take_from_grant`.
    struct Stuff {
        prod: SpscProducer,
        cons: Consumer,
        buffer: FixedVec<u8>,
    }

    impl Stuff {
        fn setup() -> Self {
            futures::executor::block_on(async {
                let (prod, cons) = bbq::new_spsc_channel(128).await;
                let buffer = FixedVec::<u8>::new(64).await;
                Stuff { prod, cons, buffer }
            })
        }

        /// Writes `data` into the channel as one committed grant.
        fn send(&self, data: &[u8]) {
            let len = data.len();
            let mut grant = self.prod.send_grant_exact_sync(len).unwrap();
            grant.copy_from_slice(data);
            grant.commit(len);
        }

        /// Takes the next read grant; panics if nothing is available.
        fn read(&self) -> GrantR {
            self.cons.read_grant_sync().unwrap()
        }

        /// Feeds the next read grant into the accumulator.
        fn take(&mut self) -> bool {
            let grant = self.read();
            take_from_grant(&mut self.buffer, grant)
        }

        fn clear(&mut self) {
            self.buffer.clear();
        }

        /// Decodes the accumulated frame and checks it is `b"!"` on port 0.
        fn expect_bang_on_port_zero(&mut self) {
            let (port_id, data) = try_decode(self.buffer.as_slice_mut()).unwrap();
            assert_eq!(port_id, 0);
            assert_eq!(data, b"!");
        }
    }

    #[test]
    fn simple_decode() {
        let mut ctxt = Stuff::setup();
        ctxt.send(BANG_FRAME);
        assert!(ctxt.take());
        assert_eq!(ctxt.buffer.as_slice(), BANG_FRAME);
        ctxt.expect_bang_on_port_zero();
    }

    #[test]
    fn empty_message() {
        // Port 0 with no payload decodes to only 2 bytes and is rejected.
        const MESSAGE: &[u8] = &[0x01, 0x01, 0x01, 0x00];
        let mut ctxt = Stuff::setup();
        ctxt.send(MESSAGE);
        assert!(ctxt.take());
        assert_eq!(ctxt.buffer.as_slice(), MESSAGE);
        assert!(try_decode(ctxt.buffer.as_slice_mut()).is_none());
    }

    #[test]
    fn fillup() {
        // `BANG_FRAME` without its delimiter, so it never completes a frame.
        const UNTERMINATED: &[u8] = &[0x01, 0x01, 0x02, b'!'];
        let mut ctxt = Stuff::setup();
        let fits = ctxt.buffer.capacity() / UNTERMINATED.len();
        for _ in 0..fits {
            ctxt.send(UNTERMINATED);
            assert!(!ctxt.take());
            assert!(!ctxt.buffer.is_empty());
        }
        // One more chunk overflows the accumulator, which is then reset.
        ctxt.send(UNTERMINATED);
        assert!(!ctxt.take());
        assert!(ctxt.buffer.is_empty());
        // A well-formed frame afterwards decodes normally.
        ctxt.send(BANG_FRAME);
        assert!(ctxt.take());
        assert_eq!(ctxt.buffer.as_slice(), BANG_FRAME);
        ctxt.expect_bang_on_port_zero();
    }

    #[test]
    fn partial_take() {
        let mut ctxt = Stuff::setup();
        ctxt.send(BANG_FRAME);
        ctxt.send(BANG_FRAME);
        // The first take stops at the first delimiter.
        assert!(ctxt.take());
        assert_eq!(ctxt.buffer.as_slice(), BANG_FRAME);
        ctxt.expect_bang_on_port_zero();
        ctxt.clear();
        // Only the second frame is left in the channel.
        let rgr = ctxt.read();
        assert_eq!(rgr.deref(), BANG_FRAME);
        assert!(take_from_grant(&mut ctxt.buffer, rgr));
        assert_eq!(ctxt.buffer.as_slice(), BANG_FRAME);
        ctxt.expect_bang_on_port_zero();
        ctxt.clear();
    }
}