# HG changeset patch # User alfadur # Date 1643738315 -10800 # Node ID a4d505a32879172fddebd2362678747df053f776 # Parent 7d0f747afcb81744f67a79f868bc8937aa54ad38 add back client timeouts diff -r 7d0f747afcb8 -r a4d505a32879 rust/hedgewars-server/src/protocol.rs --- a/rust/hedgewars-server/src/protocol.rs Tue Feb 01 02:23:35 2022 +0300 +++ b/rust/hedgewars-server/src/protocol.rs Tue Feb 01 20:58:35 2022 +0300 @@ -1,22 +1,62 @@ use bytes::{Buf, BufMut, BytesMut}; use log::*; -use std::{io, io::ErrorKind, marker::Unpin}; -use tokio::io::AsyncReadExt; +use std::{ + error::Error, + fmt::{Debug, Display, Formatter}, + io, + io::ErrorKind, + marker::Unpin, + time::Duration, +}; +use tokio::{io::AsyncReadExt, time::timeout}; +use crate::protocol::ProtocolError::Timeout; use hedgewars_network_protocol::{ messages::HwProtocolMessage, + parser::HwProtocolError, parser::{malformed_message, message}, }; +#[derive(Debug)] +pub enum ProtocolError { + Eof, + Timeout, + Network(Box), +} + +impl Display for ProtocolError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + ProtocolError::Eof => write!(f, "Connection reset by peer"), + ProtocolError::Timeout => write!(f, "Read operation timed out"), + ProtocolError::Network(source) => write!(f, "{:?}", source), + } + } +} + +impl Error for ProtocolError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + if let Self::Network(source) = self { + Some(source.as_ref()) + } else { + None + } + } +} + +pub type Result = std::result::Result; + pub struct ProtocolDecoder { buffer: BytesMut, + read_timeout: Duration, is_recovering: bool, } impl ProtocolDecoder { - pub fn new() -> ProtocolDecoder { + pub fn new(read_timeout: Duration) -> ProtocolDecoder { ProtocolDecoder { buffer: BytesMut::with_capacity(1024), + read_timeout, is_recovering: false, } } @@ -57,17 +97,21 @@ pub async fn read_from( &mut self, stream: &mut R, - ) -> Option { + ) -> Result { + use ProtocolError::*; + loop { if !self.buffer.has_remaining() { - let count = stream.read_buf(&mut self.buffer).await.ok()?; - if count == 0 { - return None; - } + match timeout(self.read_timeout, stream.read_buf(&mut self.buffer)).await { + Err(_) => return Err(Timeout), + Ok(Err(e)) => return Err(Network(Box::new(e))), + Ok(Ok(0)) => return Err(Eof), + Ok(Ok(_)) => (), + }; } while !self.buffer.is_empty() { if let Some(result) = self.extract_message() { - return Some(result); + return Ok(result); } } } diff -r 7d0f747afcb8 -r a4d505a32879 rust/hedgewars-server/src/server/network.rs --- a/rust/hedgewars-server/src/server/network.rs Tue Feb 01 02:23:35 2022 +0300 +++ b/rust/hedgewars-server/src/server/network.rs Tue Feb 01 20:58:35 2022 +0300 @@ -2,15 +2,9 @@ use log::*; use slab::Slab; use std::{ - collections::HashSet, - io, - io::{Error, ErrorKind, Read, Write}, iter::Iterator, - mem::{replace, swap}, - net::{IpAddr, Ipv4Addr, SocketAddr}, - num::NonZeroU32, + net::{IpAddr, SocketAddr}, time::Duration, - time::Instant, }; use tokio::{ io::AsyncReadExt, @@ -25,7 +19,7 @@ }, handlers, handlers::{IoResult, IoTask, ServerState}, - protocol::ProtocolDecoder, + protocol::{self, ProtocolDecoder, ProtocolError}, utils, }; use hedgewars_network_protocol::{ @@ -33,6 +27,8 @@ }; use tokio::io::AsyncWriteExt; +const PING_TIMEOUT: Duration = Duration::from_secs(15); + enum ClientUpdateData { Message(HwProtocolMessage), Error(String), @@ -80,16 +76,28 @@ socket, peer_addr, receiver, - decoder: ProtocolDecoder::new(), + decoder: ProtocolDecoder::new(PING_TIMEOUT), } } - async fn read(&mut self) -> Option { - self.decoder.read_from(&mut self.socket).await + async fn read( + socket: &mut TcpStream, + decoder: &mut ProtocolDecoder, + ) -> protocol::Result { + let result = decoder.read_from(socket).await; + if matches!(result, Err(ProtocolError::Timeout)) { + if Self::write(socket, Bytes::from(HwServerMessage::Ping.to_raw_protocol())).await { + decoder.read_from(socket).await + } else { + Err(ProtocolError::Eof) + } + } else { + result + } } - async fn write(&mut self, mut data: Bytes) -> bool { - !data.has_remaining() || matches!(self.socket.write_buf(&mut data).await, Ok(n) if n > 0) + async fn write(socket: &mut TcpStream, mut data: Bytes) -> bool { + !data.has_remaining() || matches!(socket.write_buf(&mut data).await, Ok(n) if n > 0) } async fn run(mut self, sender: Sender) { @@ -103,7 +111,7 @@ tokio::select! { server_message = self.receiver.recv() => { match server_message { - Some(message) => if !self.write(message).await { + Some(message) => if !Self::write(&mut self.socket, message).await { sender.send(Error("Connection reset by peer".to_string())).await; break; } @@ -112,15 +120,18 @@ } } } - client_message = self.decoder.read_from(&mut self.socket) => { + client_message = Self::read(&mut self.socket, &mut self.decoder) => { match client_message { - Some(message) => { + Ok(message) => { if !sender.send(Message(message)).await { break; } } - None => { - sender.send(Error("Connection reset by peer".to_string())).await; + Err(e) => { + sender.send(Error(format!("{}", e))).await; + if matches!(e, ProtocolError::Timeout) { + Self::write(&mut self.socket, Bytes::from(HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol())).await; + } break; } } @@ -153,8 +164,10 @@ Some(ClientUpdate{ client_id, data: Message(message) } ) => { self.handle_message(client_id, message).await; } - Some(ClientUpdate{ client_id, .. } ) => { + Some(ClientUpdate{ client_id, data: Error(e) } ) => { let mut response = handlers::Response::new(client_id); + info!("Client {} error: {:?}", client_id, e); + response.remove_client(client_id); handlers::handle_client_loss(&mut self.server_state, client_id, &mut response); self.handle_response(response).await; }