// rust/hedgewars-server/src/protocol.rs
// author alfadur
// Thu, 28 Mar 2024 00:25:22 +0300
// changeset 16004 2871fc57a282
// parent 16000 d9f1b239b6d7
// permissions -rw-r--r--
// fix early room deletion

use bytes::{Buf, BufMut, BytesMut};
use log::*;
use std::{
    error::Error,
    fmt::{Debug, Display, Formatter},
    io,
    io::ErrorKind,
    marker::Unpin,
    time::Duration,
};
use tokio::{io::AsyncReadExt, time::timeout};

use crate::protocol::ProtocolError::Timeout;
use hedgewars_network_protocol::{
    messages::HwProtocolMessage,
    parser::HwProtocolError,
    parser::{malformed_message, message},
};

/// Errors reported by [`ProtocolDecoder::read_from`] while reading messages from a stream.
#[derive(Debug)]
pub enum ProtocolError {
    /// A read from the stream returned zero bytes.
    Eof,
    /// A single read did not finish within the decoder's read timeout.
    Timeout,
    /// The read itself failed; the boxed value is the error returned by the stream.
    Network(Box<dyn Error + Send>),
}

impl Display for ProtocolError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            ProtocolError::Eof => write!(f, "Connection reset by peer"),
            ProtocolError::Timeout => write!(f, "Read operation timed out"),
            ProtocolError::Network(source) => write!(f, "{:?}", source),
        }
    }
}

impl Error for ProtocolError {
    /// Returns the wrapped error for `Network`; the other variants have no
    /// underlying cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::Network(source) => Some(source.as_ref()),
            Self::Eof | Self::Timeout => None,
        }
    }
}

/// Shorthand for a `std::result::Result` whose error type is [`ProtocolError`].
pub type Result<T> = std::result::Result<T, ProtocolError>;

/// Incremental decoder that parses `HwProtocolMessage`s out of bytes read from a stream.
pub struct ProtocolDecoder {
    /// Bytes received from the stream that the parser has not consumed yet.
    buffer: BytesMut,
    /// Upper bound on how long a single read from the stream may take.
    read_timeout: Duration,
    /// Set by `recover` when `malformed_message` fails on the buffered bytes
    /// (the buffer is cleared at that point); reset once it succeeds.
    is_recovering: bool,
}

impl ProtocolDecoder {
    /// Creates a decoder with an empty receive buffer (initial capacity 1024 bytes).
    ///
    /// `read_timeout` limits how long each individual read performed by
    /// [`ProtocolDecoder::read_from`] may wait for data.
    pub fn new(read_timeout: Duration) -> ProtocolDecoder {
        ProtocolDecoder {
            buffer: BytesMut::with_capacity(1024),
            read_timeout,
            is_recovering: false,
        }
    }

    /// Tries to skip over malformed input at the start of the buffer.
    ///
    /// If `malformed_message` succeeds, the bytes it consumed are dropped from
    /// the buffer and recovery mode is left. Otherwise the whole buffer is
    /// discarded and the decoder stays in (or enters) recovery mode, so the
    /// attempt is repeated on the next chunk of data.
    ///
    /// Returns `true` when the decoder is no longer recovering.
    fn recover(&mut self) -> bool {
        let recovered = match malformed_message(&self.buffer[..]) {
            Ok((tail, ())) => {
                let consumed = self.buffer.len() - tail.len();
                self.buffer.advance(consumed);
                true
            }
            Err(_) => {
                self.buffer.clear();
                false
            }
        };
        self.is_recovering = !recovered;
        recovered
    }

    /// Tries to parse one message from the front of the buffer.
    ///
    /// Returns `Some(message)` and consumes its bytes on success. Returns
    /// `None` when:
    /// * the decoder is recovering and `recover` failed (buffer was cleared),
    /// * the parser reported the input as incomplete (buffer is left untouched),
    /// * the input was invalid (a recovery attempt has been made).
    fn extract_message(&mut self) -> Option<HwProtocolMessage> {
        // While recovering, parsing only makes sense once the malformed part
        // has been skipped successfully.
        if self.is_recovering && !self.recover() {
            return None;
        }

        match message(&self.buffer[..]) {
            Ok((tail, message)) => {
                let consumed = self.buffer.len() - tail.len();
                self.buffer.advance(consumed);
                return Some(message);
            }
            // Not enough bytes yet: keep what we have and let the caller read more.
            Err(nom::Err::Incomplete(_)) => {}
            Err(nom::Err::Failure(e) | nom::Err::Error(e)) => {
                debug!("Invalid message: {:?}", e);
                trace!(
                    "Buffer content: {:?}",
                    String::from_utf8_lossy(&self.buffer[..])
                );
                self.recover();
            }
        }
        None
    }

    /// Reads from `stream` until one complete message can be returned.
    ///
    /// Bytes left over from a previous call are parsed first; the stream is
    /// only read when the buffered data is exhausted, incomplete, or the
    /// decoder is recovering from malformed input.
    ///
    /// # Errors
    ///
    /// * [`ProtocolError::Timeout`] if a read does not finish within the
    ///   configured read timeout.
    /// * [`ProtocolError::Network`] if the stream returns an error.
    /// * [`ProtocolError::Eof`] if a read returns zero bytes.
    pub async fn read_from<R: AsyncReadExt + Unpin>(
        &mut self,
        stream: &mut R,
    ) -> Result<HwProtocolMessage> {
        use ProtocolError::*;

        // Set when the previous parsing pass stopped without consuming
        // anything, i.e. the buffer holds only the beginning of a message.
        let mut needs_more_data = false;

        loop {
            let remaining = self.buffer.capacity() - self.buffer.len();
            if remaining < 1024 {
                self.buffer.reserve(2048 - remaining);
            }

            if needs_more_data || self.buffer.is_empty() || self.is_recovering {
                // TODO: ensure the buffer doesn't grow indefinitely. A message
                // that never completes keeps accumulating bytes here.
                match timeout(self.read_timeout, stream.read_buf(&mut self.buffer)).await {
                    Err(_) => return Err(Timeout),
                    Ok(Err(e)) => return Err(Network(Box::new(e))),
                    Ok(Ok(0)) => return Err(Eof),
                    Ok(Ok(_)) => (),
                };
            }

            needs_more_data = false;
            while !self.buffer.is_empty() {
                let buffered = self.buffer.len();
                if let Some(result) = self.extract_message() {
                    return Ok(result);
                }
                if self.buffer.len() == buffered {
                    // Nothing was consumed: the parser needs more input.
                    // Retrying on the same bytes would spin forever without
                    // ever awaiting, so go back and read from the stream.
                    needs_more_data = true;
                    break;
                }
            }
        }
    }
}