add client timeouts
authoralfadur
Mon, 15 Apr 2019 21:22:51 +0300
changeset 14824 92225a708bda
parent 14823 a40139603cde
child 14825 d7b40d61729f
add client timeouts
rust/hedgewars-server/src/main.rs
rust/hedgewars-server/src/server/network.rs
rust/hedgewars-server/src/utils.rs
--- a/rust/hedgewars-server/src/main.rs	Sat Apr 13 00:37:35 2019 +0300
+++ b/rust/hedgewars-server/src/main.rs	Mon Apr 15 21:22:51 2019 +0300
@@ -68,6 +68,7 @@
             if event.readiness() & Ready::readable() == Ready::readable() {
                 match event.token() {
                     utils::SERVER_TOKEN => hw_network.accept_client(&poll).unwrap(),
+                    utils::TIMER_TOKEN => hw_network.handle_timeout(&poll).unwrap(),
                     #[cfg(feature = "official-server")]
                     utils::IO_TOKEN => hw_network.handle_io_result(),
                     Token(tok) => hw_network.client_readable(&poll, tok).unwrap(),
@@ -75,8 +76,7 @@
             }
             if event.readiness() & Ready::writable() == Ready::writable() {
                 match event.token() {
-                    utils::SERVER_TOKEN => unreachable!(),
-                    utils::IO_TOKEN => unreachable!(),
+                    utils::SERVER_TOKEN | utils::TIMER_TOKEN | utils::IO_TOKEN => unreachable!(),
                     Token(tok) => hw_network.client_writable(&poll, tok).unwrap(),
                 }
             }
--- a/rust/hedgewars-server/src/server/network.rs	Sat Apr 13 00:37:35 2019 +0300
+++ b/rust/hedgewars-server/src/server/network.rs	Mon Apr 15 21:22:51 2019 +0300
@@ -13,6 +13,7 @@
     net::{TcpListener, TcpStream},
     Poll, PollOpt, Ready, Token,
 };
+use mio_extras::timer;
 use netbuf;
 use slab::Slab;
 
@@ -34,8 +35,11 @@
         SslMethod, SslOptions, SslStream, SslStreamBuilder, SslVerifyMode,
     },
 };
+use std::time::Duration;
 
 const MAX_BYTES_PER_READ: usize = 2048;
+const SEND_PING_TIMEOUT: Duration = Duration::from_secs(30);
+const DROP_CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
 
 #[derive(Hash, Eq, PartialEq, Copy, Clone)]
 pub enum NetworkClientState {
@@ -80,16 +84,23 @@
     peer_addr: SocketAddr,
     decoder: ProtocolDecoder,
     buf_out: netbuf::Buf,
+    timeout: timer::Timeout,
 }
 
 impl NetworkClient {
-    pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient {
+    pub fn new(
+        id: ClientId,
+        socket: ClientSocket,
+        peer_addr: SocketAddr,
+        timeout: timer::Timeout,
+    ) -> NetworkClient {
         NetworkClient {
             id,
             socket,
             peer_addr,
             decoder: ProtocolDecoder::new(),
             buf_out: netbuf::Buf::new(),
+            timeout,
         }
     }
 
@@ -231,6 +242,10 @@
     pub fn send_string(&mut self, msg: &str) {
         self.send_raw_msg(&msg.as_bytes());
     }
+
+    pub fn replace_timeout(&mut self, timeout: timer::Timeout) -> timer::Timeout {
+        replace(&mut self.timeout, timeout)
+    }
 }
 
 #[cfg(feature = "tls-connections")]
@@ -288,6 +303,13 @@
     }
 }
 
+enum TimeoutEvent {
+    SendPing,
+    DropClient,
+}
+
+struct TimerData(TimeoutEvent, ClientId);
+
 pub struct NetworkLayer {
     listener: TcpListener,
     server: HWServer,
@@ -298,6 +320,21 @@
     ssl: ServerSsl,
     #[cfg(feature = "official-server")]
     io: IoLayer,
+    timer: timer::Timer<TimerData>,
+}
+
+fn create_ping_timeout(timer: &mut timer::Timer<TimerData>, client_id: ClientId) -> timer::Timeout {
+    timer.set_timeout(
+        SEND_PING_TIMEOUT,
+        TimerData(TimeoutEvent::SendPing, client_id),
+    )
+}
+
+fn create_drop_timeout(timer: &mut timer::Timer<TimerData>, client_id: ClientId) -> timer::Timeout {
+    timer.set_timeout(
+        DROP_CLIENT_TIMEOUT,
+        TimerData(TimeoutEvent::DropClient, client_id),
+    )
 }
 
 impl NetworkLayer {
@@ -306,6 +343,7 @@
         let clients = Slab::with_capacity(clients_limit);
         let pending = HashSet::with_capacity(2 * clients_limit);
         let pending_cache = Vec::with_capacity(2 * clients_limit);
+        let timer = timer::Builder::default().build();
 
         NetworkLayer {
             listener,
@@ -317,6 +355,7 @@
             ssl: NetworkLayer::create_ssl_context(),
             #[cfg(feature = "official-server")]
             io: IoLayer::new(),
+            timer,
         }
     }
 
@@ -346,6 +385,13 @@
             PollOpt::edge(),
         )?;
 
+        poll.register(
+            &self.timer,
+            utils::TIMER_TOKEN,
+            Ready::readable(),
+            PollOpt::edge(),
+        )?;
+
         #[cfg(feature = "official-server")]
         self.io.io_thread.register_rx(poll, utils::IO_TOKEN)?;
 
@@ -384,7 +430,12 @@
         )
         .expect("could not register socket with event loop");
 
-        let client = NetworkClient::new(client_id, client_socket, addr);
+        let client = NetworkClient::new(
+            client_id,
+            client_socket,
+            addr,
+            create_ping_timeout(&mut self.timer, client_id),
+        );
         info!("client {} ({}) added", client.id, client.peer_addr);
         entry.insert(client);
 
@@ -419,6 +470,29 @@
         }
     }
 
+    pub fn handle_timeout(&mut self, poll: &Poll) -> io::Result<()> {
+        while let Some(TimerData(event, client_id)) = self.timer.poll() {
+            match event {
+                TimeoutEvent::SendPing => {
+                    if let Some(ref mut client) = self.clients.get_mut(client_id) {
+                        client.send_string(&HWServerMessage::Ping.to_raw_protocol());
+                        client.write()?;
+                        client.replace_timeout(create_drop_timeout(&mut self.timer, client_id));
+                    }
+                }
+                TimeoutEvent::DropClient => {
+                    self.operation_failed(
+                        poll,
+                        client_id,
+                        &ErrorKind::TimedOut.into(),
+                        "No ping response",
+                    )?;
+                }
+            }
+        }
+        Ok(())
+    }
+
     #[cfg(feature = "official-server")]
     pub fn handle_io_result(&mut self) {
         if let Some((client_id, result)) = self.io.try_recv() {
@@ -486,6 +560,8 @@
 
     pub fn client_readable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
         let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) {
+            let timeout = client.replace_timeout(create_ping_timeout(&mut self.timer, client_id));
+            self.timer.cancel_timeout(&timeout);
             client.read()
         } else {
             warn!("invalid readable client: {}", client_id);
--- a/rust/hedgewars-server/src/utils.rs	Sat Apr 13 00:37:35 2019 +0300
+++ b/rust/hedgewars-server/src/utils.rs	Mon Apr 15 21:22:51 2019 +0300
@@ -4,7 +4,8 @@
 
 pub const SERVER_VERSION: u32 = 3;
 pub const SERVER_TOKEN: mio::Token = mio::Token(1_000_000_000);
-pub const IO_TOKEN: mio::Token = mio::Token(1_000_000_001);
+pub const TIMER_TOKEN: mio::Token = mio::Token(1_000_000_001);
+pub const IO_TOKEN: mio::Token = mio::Token(1_000_000_002);
 
 pub fn is_name_illegal(name: &str) -> bool {
     name.len() > 40