--- a/rust/hedgewars-server/src/protocol.rs Tue Feb 01 02:23:35 2022 +0300
+++ b/rust/hedgewars-server/src/protocol.rs Tue Feb 01 20:58:35 2022 +0300
@@ -1,22 +1,62 @@
use bytes::{Buf, BufMut, BytesMut};
use log::*;
-use std::{io, io::ErrorKind, marker::Unpin};
-use tokio::io::AsyncReadExt;
+use std::{
+ error::Error,
+ fmt::{Debug, Display, Formatter},
+ io,
+ io::ErrorKind,
+ marker::Unpin,
+ time::Duration,
+};
+use tokio::{io::AsyncReadExt, time::timeout};
+use crate::protocol::ProtocolError::Timeout;
use hedgewars_network_protocol::{
messages::HwProtocolMessage,
+ parser::HwProtocolError,
parser::{malformed_message, message},
};
+#[derive(Debug)]
+pub enum ProtocolError {
+ Eof,
+ Timeout,
+ Network(Box<dyn Error + Send>),
+}
+
+impl Display for ProtocolError {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ match self {
+ ProtocolError::Eof => write!(f, "Connection reset by peer"),
+ ProtocolError::Timeout => write!(f, "Read operation timed out"),
+ ProtocolError::Network(source) => write!(f, "{:?}", source),
+ }
+ }
+}
+
+impl Error for ProtocolError {
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ if let Self::Network(source) = self {
+ Some(source.as_ref())
+ } else {
+ None
+ }
+ }
+}
+
+pub type Result<T> = std::result::Result<T, ProtocolError>;
+
pub struct ProtocolDecoder {
buffer: BytesMut,
+ read_timeout: Duration,
is_recovering: bool,
}
impl ProtocolDecoder {
- pub fn new() -> ProtocolDecoder {
+ pub fn new(read_timeout: Duration) -> ProtocolDecoder {
ProtocolDecoder {
buffer: BytesMut::with_capacity(1024),
+ read_timeout,
is_recovering: false,
}
}
@@ -57,17 +97,21 @@
pub async fn read_from<R: AsyncReadExt + Unpin>(
&mut self,
stream: &mut R,
- ) -> Option<HwProtocolMessage> {
+ ) -> Result<HwProtocolMessage> {
+ use ProtocolError::*;
+
loop {
if !self.buffer.has_remaining() {
- let count = stream.read_buf(&mut self.buffer).await.ok()?;
- if count == 0 {
- return None;
- }
+ match timeout(self.read_timeout, stream.read_buf(&mut self.buffer)).await {
+ Err(_) => return Err(Timeout),
+ Ok(Err(e)) => return Err(Network(Box::new(e))),
+ Ok(Ok(0)) => return Err(Eof),
+ Ok(Ok(_)) => (),
+ };
}
while !self.buffer.is_empty() {
if let Some(result) = self.extract_message() {
- return Some(result);
+ return Ok(result);
}
}
}
--- a/rust/hedgewars-server/src/server/network.rs Tue Feb 01 02:23:35 2022 +0300
+++ b/rust/hedgewars-server/src/server/network.rs Tue Feb 01 20:58:35 2022 +0300
@@ -2,15 +2,9 @@
use log::*;
use slab::Slab;
use std::{
- collections::HashSet,
- io,
- io::{Error, ErrorKind, Read, Write},
iter::Iterator,
- mem::{replace, swap},
- net::{IpAddr, Ipv4Addr, SocketAddr},
- num::NonZeroU32,
+ net::{IpAddr, SocketAddr},
time::Duration,
- time::Instant,
};
use tokio::{
io::AsyncReadExt,
@@ -25,7 +19,7 @@
},
handlers,
handlers::{IoResult, IoTask, ServerState},
- protocol::ProtocolDecoder,
+ protocol::{self, ProtocolDecoder, ProtocolError},
utils,
};
use hedgewars_network_protocol::{
@@ -33,6 +27,8 @@
};
use tokio::io::AsyncWriteExt;
+const PING_TIMEOUT: Duration = Duration::from_secs(15);
+
enum ClientUpdateData {
Message(HwProtocolMessage),
Error(String),
@@ -80,16 +76,28 @@
socket,
peer_addr,
receiver,
- decoder: ProtocolDecoder::new(),
+ decoder: ProtocolDecoder::new(PING_TIMEOUT),
}
}
- async fn read(&mut self) -> Option<HwProtocolMessage> {
- self.decoder.read_from(&mut self.socket).await
+ async fn read(
+ socket: &mut TcpStream,
+ decoder: &mut ProtocolDecoder,
+ ) -> protocol::Result<HwProtocolMessage> {
+ let result = decoder.read_from(socket).await;
+ if matches!(result, Err(ProtocolError::Timeout)) {
+ if Self::write(socket, Bytes::from(HwServerMessage::Ping.to_raw_protocol())).await {
+ decoder.read_from(socket).await
+ } else {
+ Err(ProtocolError::Eof)
+ }
+ } else {
+ result
+ }
}
- async fn write(&mut self, mut data: Bytes) -> bool {
- !data.has_remaining() || matches!(self.socket.write_buf(&mut data).await, Ok(n) if n > 0)
+ async fn write(socket: &mut TcpStream, mut data: Bytes) -> bool {
+ !data.has_remaining() || matches!(socket.write_buf(&mut data).await, Ok(n) if n > 0)
}
async fn run(mut self, sender: Sender<ClientUpdate>) {
@@ -103,7 +111,7 @@
tokio::select! {
server_message = self.receiver.recv() => {
match server_message {
- Some(message) => if !self.write(message).await {
+ Some(message) => if !Self::write(&mut self.socket, message).await {
sender.send(Error("Connection reset by peer".to_string())).await;
break;
}
@@ -112,15 +120,18 @@
}
}
}
- client_message = self.decoder.read_from(&mut self.socket) => {
+ client_message = Self::read(&mut self.socket, &mut self.decoder) => {
match client_message {
- Some(message) => {
+ Ok(message) => {
if !sender.send(Message(message)).await {
break;
}
}
- None => {
- sender.send(Error("Connection reset by peer".to_string())).await;
+ Err(e) => {
+ sender.send(Error(format!("{}", e))).await;
+ if matches!(e, ProtocolError::Timeout) {
+ Self::write(&mut self.socket, Bytes::from(HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol())).await;
+ }
break;
}
}
@@ -153,8 +164,10 @@
Some(ClientUpdate{ client_id, data: Message(message) } ) => {
self.handle_message(client_id, message).await;
}
- Some(ClientUpdate{ client_id, .. } ) => {
+ Some(ClientUpdate{ client_id, data: Error(e) } ) => {
let mut response = handlers::Response::new(client_id);
+ info!("Client {} error: {:?}", client_id, e);
+ response.remove_client(client_id);
handlers::handle_client_loss(&mut self.server_state, client_id, &mut response);
self.handle_response(response).await;
}