rust/hedgewars-server/src/server/network.rs
changeset 14803 92225a708bda
parent 14796 f5d43f007970
child 14807 b2beb784e4b5
--- a/rust/hedgewars-server/src/server/network.rs	Sat Apr 13 00:37:35 2019 +0300
+++ b/rust/hedgewars-server/src/server/network.rs	Mon Apr 15 21:22:51 2019 +0300
@@ -13,6 +13,7 @@
     net::{TcpListener, TcpStream},
     Poll, PollOpt, Ready, Token,
 };
+use mio_extras::timer;
 use netbuf;
 use slab::Slab;
 
@@ -34,8 +35,11 @@
         SslMethod, SslOptions, SslStream, SslStreamBuilder, SslVerifyMode,
     },
 };
+use std::time::Duration;
 
 const MAX_BYTES_PER_READ: usize = 2048;
+const SEND_PING_TIMEOUT: Duration = Duration::from_secs(30);
+const DROP_CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
 
 #[derive(Hash, Eq, PartialEq, Copy, Clone)]
 pub enum NetworkClientState {
@@ -80,16 +84,23 @@
     peer_addr: SocketAddr,
     decoder: ProtocolDecoder,
     buf_out: netbuf::Buf,
+    timeout: timer::Timeout,
 }
 
 impl NetworkClient {
-    pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient {
+    pub fn new(
+        id: ClientId,
+        socket: ClientSocket,
+        peer_addr: SocketAddr,
+        timeout: timer::Timeout,
+    ) -> NetworkClient {
         NetworkClient {
             id,
             socket,
             peer_addr,
             decoder: ProtocolDecoder::new(),
             buf_out: netbuf::Buf::new(),
+            timeout,
         }
     }
 
@@ -231,6 +242,10 @@
     pub fn send_string(&mut self, msg: &str) {
         self.send_raw_msg(&msg.as_bytes());
     }
+
+    pub fn replace_timeout(&mut self, timeout: timer::Timeout) -> timer::Timeout {
+        replace(&mut self.timeout, timeout)
+    }
 }
 
 #[cfg(feature = "tls-connections")]
@@ -288,6 +303,13 @@
     }
 }
 
+enum TimeoutEvent {
+    SendPing,
+    DropClient,
+}
+
+struct TimerData(TimeoutEvent, ClientId);
+
 pub struct NetworkLayer {
     listener: TcpListener,
     server: HWServer,
@@ -298,6 +320,21 @@
     ssl: ServerSsl,
     #[cfg(feature = "official-server")]
     io: IoLayer,
+    timer: timer::Timer<TimerData>,
+}
+
+fn create_ping_timeout(timer: &mut timer::Timer<TimerData>, client_id: ClientId) -> timer::Timeout {
+    timer.set_timeout(
+        SEND_PING_TIMEOUT,
+        TimerData(TimeoutEvent::SendPing, client_id),
+    )
+}
+
+fn create_drop_timeout(timer: &mut timer::Timer<TimerData>, client_id: ClientId) -> timer::Timeout {
+    timer.set_timeout(
+        DROP_CLIENT_TIMEOUT,
+        TimerData(TimeoutEvent::DropClient, client_id),
+    )
 }
 
 impl NetworkLayer {
@@ -306,6 +343,7 @@
         let clients = Slab::with_capacity(clients_limit);
         let pending = HashSet::with_capacity(2 * clients_limit);
         let pending_cache = Vec::with_capacity(2 * clients_limit);
+        let timer = timer::Builder::default().build();
 
         NetworkLayer {
             listener,
@@ -317,6 +355,7 @@
             ssl: NetworkLayer::create_ssl_context(),
             #[cfg(feature = "official-server")]
             io: IoLayer::new(),
+            timer,
         }
     }
 
@@ -346,6 +385,13 @@
             PollOpt::edge(),
         )?;
 
+        poll.register(
+            &self.timer,
+            utils::TIMER_TOKEN,
+            Ready::readable(),
+            PollOpt::edge(),
+        )?;
+
         #[cfg(feature = "official-server")]
         self.io.io_thread.register_rx(poll, utils::IO_TOKEN)?;
 
@@ -384,7 +430,12 @@
         )
         .expect("could not register socket with event loop");
 
-        let client = NetworkClient::new(client_id, client_socket, addr);
+        let client = NetworkClient::new(
+            client_id,
+            client_socket,
+            addr,
+            create_ping_timeout(&mut self.timer, client_id),
+        );
         info!("client {} ({}) added", client.id, client.peer_addr);
         entry.insert(client);
 
@@ -419,6 +470,29 @@
         }
     }
 
+    pub fn handle_timeout(&mut self, poll: &Poll) -> io::Result<()> {
+        while let Some(TimerData(event, client_id)) = self.timer.poll() {
+            match event {
+                TimeoutEvent::SendPing => {
+                    if let Some(ref mut client) = self.clients.get_mut(client_id) {
+                        client.send_string(&HWServerMessage::Ping.to_raw_protocol());
+                        client.write()?;
+                        client.replace_timeout(create_drop_timeout(&mut self.timer, client_id));
+                    }
+                }
+                TimeoutEvent::DropClient => {
+                    self.operation_failed(
+                        poll,
+                        client_id,
+                        &ErrorKind::TimedOut.into(),
+                        "No ping response",
+                    )?;
+                }
+            }
+        }
+        Ok(())
+    }
+
     #[cfg(feature = "official-server")]
     pub fn handle_io_result(&mut self) {
         if let Some((client_id, result)) = self.io.try_recv() {
@@ -486,6 +560,8 @@
 
     pub fn client_readable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
         let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) {
+            let timeout = client.replace_timeout(create_ping_timeout(&mut self.timer, client_id));
+            self.timer.cancel_timeout(&timeout);
             client.read()
         } else {
             warn!("invalid readable client: {}", client_id);