add back client timeouts
authoralfadur
Tue, 01 Feb 2022 20:58:35 +0300
changeset 15832 a4d505a32879
parent 15831 7d0f747afcb8
child 15833 3511bacbd763
add back client timeouts
rust/hedgewars-server/src/protocol.rs
rust/hedgewars-server/src/server/network.rs
--- a/rust/hedgewars-server/src/protocol.rs	Tue Feb 01 02:23:35 2022 +0300
+++ b/rust/hedgewars-server/src/protocol.rs	Tue Feb 01 20:58:35 2022 +0300
@@ -1,22 +1,62 @@
 use bytes::{Buf, BufMut, BytesMut};
 use log::*;
-use std::{io, io::ErrorKind, marker::Unpin};
-use tokio::io::AsyncReadExt;
+use std::{
+    error::Error,
+    fmt::{Debug, Display, Formatter},
+    io,
+    io::ErrorKind,
+    marker::Unpin,
+    time::Duration,
+};
+use tokio::{io::AsyncReadExt, time::timeout};
 
+use crate::protocol::ProtocolError::Timeout;
 use hedgewars_network_protocol::{
     messages::HwProtocolMessage,
+    parser::HwProtocolError,
     parser::{malformed_message, message},
 };
 
+#[derive(Debug)]
+pub enum ProtocolError {
+    Eof,
+    Timeout,
+    Network(Box<dyn Error + Send>),
+}
+
+impl Display for ProtocolError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            ProtocolError::Eof => write!(f, "Connection reset by peer"),
+            ProtocolError::Timeout => write!(f, "Read operation timed out"),
+            ProtocolError::Network(source) => write!(f, "{:?}", source),
+        }
+    }
+}
+
+impl Error for ProtocolError {
+    fn source(&self) -> Option<&(dyn Error + 'static)> {
+        if let Self::Network(source) = self {
+            Some(source.as_ref())
+        } else {
+            None
+        }
+    }
+}
+
+pub type Result<T> = std::result::Result<T, ProtocolError>;
+
 pub struct ProtocolDecoder {
     buffer: BytesMut,
+    read_timeout: Duration,
     is_recovering: bool,
 }
 
 impl ProtocolDecoder {
-    pub fn new() -> ProtocolDecoder {
+    pub fn new(read_timeout: Duration) -> ProtocolDecoder {
         ProtocolDecoder {
             buffer: BytesMut::with_capacity(1024),
+            read_timeout,
             is_recovering: false,
         }
     }
@@ -57,17 +97,21 @@
     pub async fn read_from<R: AsyncReadExt + Unpin>(
         &mut self,
         stream: &mut R,
-    ) -> Option<HwProtocolMessage> {
+    ) -> Result<HwProtocolMessage> {
+        use ProtocolError::*;
+
         loop {
             if !self.buffer.has_remaining() {
-                let count = stream.read_buf(&mut self.buffer).await.ok()?;
-                if count == 0 {
-                    return None;
-                }
+                match timeout(self.read_timeout, stream.read_buf(&mut self.buffer)).await {
+                    Err(_) => return Err(Timeout),
+                    Ok(Err(e)) => return Err(Network(Box::new(e))),
+                    Ok(Ok(0)) => return Err(Eof),
+                    Ok(Ok(_)) => (),
+                };
             }
             while !self.buffer.is_empty() {
                 if let Some(result) = self.extract_message() {
-                    return Some(result);
+                    return Ok(result);
                 }
             }
         }
--- a/rust/hedgewars-server/src/server/network.rs	Tue Feb 01 02:23:35 2022 +0300
+++ b/rust/hedgewars-server/src/server/network.rs	Tue Feb 01 20:58:35 2022 +0300
@@ -2,15 +2,9 @@
 use log::*;
 use slab::Slab;
 use std::{
-    collections::HashSet,
-    io,
-    io::{Error, ErrorKind, Read, Write},
     iter::Iterator,
-    mem::{replace, swap},
-    net::{IpAddr, Ipv4Addr, SocketAddr},
-    num::NonZeroU32,
+    net::{IpAddr, SocketAddr},
     time::Duration,
-    time::Instant,
 };
 use tokio::{
     io::AsyncReadExt,
@@ -25,7 +19,7 @@
     },
     handlers,
     handlers::{IoResult, IoTask, ServerState},
-    protocol::ProtocolDecoder,
+    protocol::{self, ProtocolDecoder, ProtocolError},
     utils,
 };
 use hedgewars_network_protocol::{
@@ -33,6 +27,8 @@
 };
 use tokio::io::AsyncWriteExt;
 
+const PING_TIMEOUT: Duration = Duration::from_secs(15);
+
 enum ClientUpdateData {
     Message(HwProtocolMessage),
     Error(String),
@@ -80,16 +76,28 @@
             socket,
             peer_addr,
             receiver,
-            decoder: ProtocolDecoder::new(),
+            decoder: ProtocolDecoder::new(PING_TIMEOUT),
         }
     }
 
-    async fn read(&mut self) -> Option<HwProtocolMessage> {
-        self.decoder.read_from(&mut self.socket).await
+    async fn read(
+        socket: &mut TcpStream,
+        decoder: &mut ProtocolDecoder,
+    ) -> protocol::Result<HwProtocolMessage> {
+        let result = decoder.read_from(socket).await;
+        if matches!(result, Err(ProtocolError::Timeout)) {
+            if Self::write(socket, Bytes::from(HwServerMessage::Ping.to_raw_protocol())).await {
+                decoder.read_from(socket).await
+            } else {
+                Err(ProtocolError::Eof)
+            }
+        } else {
+            result
+        }
     }
 
-    async fn write(&mut self, mut data: Bytes) -> bool {
-        !data.has_remaining() || matches!(self.socket.write_buf(&mut data).await, Ok(n) if n > 0)
+    async fn write(socket: &mut TcpStream, mut data: Bytes) -> bool {
+        !data.has_remaining() || matches!(socket.write_buf(&mut data).await, Ok(n) if n > 0)
     }
 
     async fn run(mut self, sender: Sender<ClientUpdate>) {
@@ -103,7 +111,7 @@
             tokio::select! {
                 server_message = self.receiver.recv() => {
                     match server_message {
-                        Some(message) => if !self.write(message).await {
+                        Some(message) => if !Self::write(&mut self.socket, message).await {
                             sender.send(Error("Connection reset by peer".to_string())).await;
                             break;
                         }
@@ -112,15 +120,18 @@
                         }
                     }
                 }
-                client_message = self.decoder.read_from(&mut self.socket) => {
+                client_message = Self::read(&mut self.socket, &mut self.decoder) => {
                      match client_message {
-                        Some(message) => {
+                        Ok(message) => {
                             if !sender.send(Message(message)).await {
                                 break;
                             }
                         }
-                        None => {
-                            sender.send(Error("Connection reset by peer".to_string())).await;
+                        Err(e) => {
+                            sender.send(Error(format!("{}", e))).await;
+                            if matches!(e, ProtocolError::Timeout) {
+                                Self::write(&mut self.socket, Bytes::from(HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol())).await;
+                            }
                             break;
                         }
                     }
@@ -153,8 +164,10 @@
                         Some(ClientUpdate{ client_id, data: Message(message) } ) => {
                             self.handle_message(client_id, message).await;
                         }
-                        Some(ClientUpdate{ client_id, .. } ) => {
+                        Some(ClientUpdate{ client_id, data: Error(e) } ) => {
                             let mut response = handlers::Response::new(client_id);
+                            info!("Client {} error: {:?}", client_id, e);
+                            response.remove_client(client_id);
                             handlers::handle_client_loss(&mut self.server_state, client_id, &mut response);
                             self.handle_response(response).await;
                         }