rust/hedgewars-server/src/server/network.rs
changeset 15822 6af892a0a4b8
parent 15553 ede5f4ec48f3
child 15826 747278149393
--- a/rust/hedgewars-server/src/server/network.rs	Mon Jun 21 20:11:22 2021 +0300
+++ b/rust/hedgewars-server/src/server/network.rs	Tue Jun 22 01:41:33 2021 +0300
@@ -6,19 +6,25 @@
     io::{Error, ErrorKind, Read, Write},
     mem::{replace, swap},
     net::{IpAddr, Ipv4Addr, SocketAddr},
+    num::NonZeroU32,
+    time::Duration,
+    time::Instant,
 };
 
 use log::*;
 use mio::{
+    event::Source,
     net::{TcpListener, TcpStream},
-    Evented, Poll, PollOpt, Ready, Token,
+    Interest, Poll, Token, Waker,
 };
-use mio_extras::timer;
 use netbuf;
 use slab::Slab;
 
 use crate::{
-    core::types::ClientId,
+    core::{
+        events::{TimedEvents, Timeout},
+        types::ClientId,
+    },
     handlers,
     handlers::{IoResult, IoTask, ServerState},
     protocol::{messages::HwServerMessage::Redirect, messages::*, ProtocolDecoder},
@@ -36,11 +42,11 @@
         SslMethod, SslOptions, SslStream, SslStreamBuilder, SslVerifyMode,
     },
 };
-use std::time::Duration;
 
 const MAX_BYTES_PER_READ: usize = 2048;
-const SEND_PING_TIMEOUT: Duration = Duration::from_secs(30);
-const DROP_CLIENT_TIMEOUT: Duration = Duration::from_secs(30);
+const SEND_PING_TIMEOUT: Duration = Duration::from_secs(5);
+const DROP_CLIENT_TIMEOUT: Duration = Duration::from_secs(5);
+const MAX_TIMEOUT: usize = DROP_CLIENT_TIMEOUT.as_secs() as usize;
 const PING_PROBES_COUNT: u8 = 2;
 
 #[derive(Hash, Eq, PartialEq, Copy, Clone)]
@@ -64,15 +70,15 @@
 }
 
 impl ClientSocket {
-    fn inner(&self) -> &TcpStream {
+    fn inner_mut(&mut self) -> &mut TcpStream {
         match self {
             ClientSocket::Plain(stream) => stream,
             #[cfg(feature = "tls-connections")]
-            ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(),
+            ClientSocket::SslHandshake(Some(builder)) => builder.get_mut(),
             #[cfg(feature = "tls-connections")]
             ClientSocket::SslHandshake(None) => unreachable!(),
             #[cfg(feature = "tls-connections")]
-            ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref(),
+            ClientSocket::SslStream(ssl_stream) => ssl_stream.get_mut(),
         }
     }
 }
@@ -83,8 +89,9 @@
     peer_addr: SocketAddr,
     decoder: ProtocolDecoder,
     buf_out: netbuf::Buf,
-    timeout: timer::Timeout,
     pending_close: bool,
+    timeout: Timeout,
+    last_rx_time: Instant,
 }
 
 impl NetworkClient {
@@ -92,7 +99,7 @@
         id: ClientId,
         socket: ClientSocket,
         peer_addr: SocketAddr,
-        timeout: timer::Timeout,
+        timeout: Timeout,
     ) -> NetworkClient {
         NetworkClient {
             id,
@@ -100,8 +107,9 @@
             peer_addr,
             decoder: ProtocolDecoder::new(),
             buf_out: netbuf::Buf::new(),
+            pending_close: false,
             timeout,
-            pending_close: false,
+            last_rx_time: Instant::now(),
         }
     }
 
@@ -171,7 +179,7 @@
     }
 
     pub fn read(&mut self) -> NetworkResult<Vec<HwProtocolMessage>> {
-        match self.socket {
+        let result = match self.socket {
             ClientSocket::Plain(ref mut stream) => {
                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
             }
@@ -184,7 +192,13 @@
             ClientSocket::SslStream(ref mut stream) => {
                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
             }
+        };
+
+        if let Ok(_) = result {
+            self.last_rx_time = Instant::now();
         }
+
+        result
     }
 
     fn write_impl<W: Write>(
@@ -231,7 +245,7 @@
             }
         };
 
-        self.socket.inner().flush()?;
+        self.socket.inner_mut().flush()?;
         result
     }
 
@@ -243,7 +257,7 @@
         self.send_raw_msg(&msg.as_bytes());
     }
 
-    pub fn replace_timeout(&mut self, timeout: timer::Timeout) -> timer::Timeout {
+    pub fn replace_timeout(&mut self, timeout: Timeout) -> Timeout {
         replace(&mut self.timeout, timeout)
     }
 
@@ -267,11 +281,11 @@
 
 #[cfg(feature = "official-server")]
 impl IoLayer {
-    fn new() -> Self {
+    fn new(waker: Waker) -> Self {
         Self {
             next_request_id: 0,
             request_queue: vec![],
-            io_thread: IoThread::new(),
+            io_thread: IoThread::new(waker),
         }
     }
 
@@ -314,6 +328,7 @@
 }
 
 struct TimerData(TimeoutEvent, ClientId);
+type NetworkTimeoutEvents = TimedEvents<TimerData, MAX_TIMEOUT>;
 
 pub struct NetworkLayer {
     listener: TcpListener,
@@ -325,47 +340,44 @@
     ssl: ServerSsl,
     #[cfg(feature = "official-server")]
     io: IoLayer,
-    timer: timer::Timer<TimerData>,
+    timeout_events: NetworkTimeoutEvents,
 }
 
-fn register_read<E: Evented>(poll: &Poll, evented: &E, token: mio::Token) -> io::Result<()> {
-    poll.register(evented, token, Ready::readable(), PollOpt::edge())
+fn register_read<S: Source>(poll: &Poll, source: &mut S, token: mio::Token) -> io::Result<()> {
+    poll.registry().register(source, token, Interest::READABLE)
 }
 
 fn create_ping_timeout(
-    timer: &mut timer::Timer<TimerData>,
+    timeout_events: &mut NetworkTimeoutEvents,
     probes_count: u8,
     client_id: ClientId,
-) -> timer::Timeout {
-    timer.set_timeout(
-        SEND_PING_TIMEOUT,
+) -> Timeout {
+    timeout_events.set_timeout(
+        NonZeroU32::new(SEND_PING_TIMEOUT.as_secs() as u32).unwrap(),
         TimerData(TimeoutEvent::SendPing { probes_count }, client_id),
     )
 }
 
-fn create_drop_timeout(timer: &mut timer::Timer<TimerData>, client_id: ClientId) -> timer::Timeout {
-    timer.set_timeout(
-        DROP_CLIENT_TIMEOUT,
+fn create_drop_timeout(timeout_events: &mut NetworkTimeoutEvents, client_id: ClientId) -> Timeout {
+    timeout_events.set_timeout(
+        NonZeroU32::new(DROP_CLIENT_TIMEOUT.as_secs() as u32).unwrap(),
         TimerData(TimeoutEvent::DropClient, client_id),
     )
 }
 
 impl NetworkLayer {
-    pub fn register(&self, poll: &Poll) -> io::Result<()> {
-        register_read(poll, &self.listener, utils::SERVER_TOKEN)?;
+    pub fn register(&mut self, poll: &Poll) -> io::Result<()> {
+        register_read(poll, &mut self.listener, utils::SERVER_TOKEN)?;
         #[cfg(feature = "tls-connections")]
-        register_read(poll, &self.ssl.listener, utils::SECURE_SERVER_TOKEN)?;
-        register_read(poll, &self.timer, utils::TIMER_TOKEN)?;
-
-        #[cfg(feature = "official-server")]
-        self.io.io_thread.register_rx(poll, utils::IO_TOKEN)?;
+        register_read(poll, &mut self.ssl.listener, utils::SECURE_SERVER_TOKEN)?;
 
         Ok(())
     }
 
     fn deregister_client(&mut self, poll: &Poll, id: ClientId, is_error: bool) {
         if let Some(ref mut client) = self.clients.get_mut(id) {
-            poll.deregister(client.socket.inner())
+            poll.registry()
+                .deregister(client.socket.inner_mut())
                 .expect("could not deregister socket");
             if client.has_pending_sends() && !is_error {
                 info!(
@@ -373,15 +385,11 @@
                     client.id, client.peer_addr
                 );
                 client.pending_close = true;
-                poll.register(
-                    client.socket.inner(),
-                    Token(id),
-                    Ready::writable(),
-                    PollOpt::edge(),
-                )
-                .unwrap_or_else(|_| {
-                    self.clients.remove(id);
-                });
+                poll.registry()
+                    .register(client.socket.inner_mut(), Token(id), Interest::WRITABLE)
+                    .unwrap_or_else(|_| {
+                        self.clients.remove(id);
+                    });
             } else {
                 info!("client {} ({}) removed", client.id, client.peer_addr);
                 self.clients.remove(id);
@@ -394,24 +402,23 @@
     fn register_client(
         &mut self,
         poll: &Poll,
-        client_socket: ClientSocket,
+        mut client_socket: ClientSocket,
         addr: SocketAddr,
     ) -> io::Result<ClientId> {
         let entry = self.clients.vacant_entry();
         let client_id = entry.key();
 
-        poll.register(
-            client_socket.inner(),
+        poll.registry().register(
+            client_socket.inner_mut(),
             Token(client_id),
-            Ready::readable() | Ready::writable(),
-            PollOpt::edge(),
+            Interest::READABLE | Interest::WRITABLE,
         )?;
 
         let client = NetworkClient::new(
             client_id,
             client_socket,
             addr,
-            create_ping_timeout(&mut self.timer, PING_PROBES_COUNT - 1, client_id),
+            create_ping_timeout(&mut self.timeout_events, PING_PROBES_COUNT - 1, client_id),
         );
         info!("client {} ({}) added", client.id, client.peer_addr);
         entry.insert(client);
@@ -451,34 +458,45 @@
         }
     }
 
-    pub fn handle_timeout(&mut self, poll: &Poll) -> io::Result<()> {
-        while let Some(TimerData(event, client_id)) = self.timer.poll() {
-            match event {
-                TimeoutEvent::SendPing { probes_count } => {
-                    if let Some(ref mut client) = self.clients.get_mut(client_id) {
-                        client.send_string(&HwServerMessage::Ping.to_raw_protocol());
-                        client.write()?;
-                        let timeout = if probes_count != 0 {
-                            create_ping_timeout(&mut self.timer, probes_count - 1, client_id)
-                        } else {
-                            create_drop_timeout(&mut self.timer, client_id)
-                        };
-                        client.replace_timeout(timeout);
+    pub fn handle_timeout(&mut self, poll: &mut Poll) -> io::Result<()> {
+        for TimerData(event, client_id) in self.timeout_events.poll(Instant::now()) {
+            if let Some(client) = self.clients.get_mut(client_id) {
+                if client.last_rx_time.elapsed() > SEND_PING_TIMEOUT {
+                    match event {
+                        TimeoutEvent::SendPing { probes_count } => {
+                            client.send_string(&HwServerMessage::Ping.to_raw_protocol());
+                            client.write()?;
+                            let timeout = if probes_count != 0 {
+                                create_ping_timeout(
+                                    &mut self.timeout_events,
+                                    probes_count - 1,
+                                    client_id,
+                                )
+                            } else {
+                                create_drop_timeout(&mut self.timeout_events, client_id)
+                            };
+                            client.replace_timeout(timeout);
+                        }
+                        TimeoutEvent::DropClient => {
+                            client.send_string(
+                                &HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol(),
+                            );
+                            let _res = client.write();
+
+                            self.operation_failed(
+                                poll,
+                                client_id,
+                                &ErrorKind::TimedOut.into(),
+                                "No ping response",
+                            )?;
+                        }
                     }
-                }
-                TimeoutEvent::DropClient => {
-                    if let Some(ref mut client) = self.clients.get_mut(client_id) {
-                        client.send_string(
-                            &HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol(),
-                        );
-                        let _res = client.write();
-                    }
-                    self.operation_failed(
-                        poll,
+                } else {
+                    client.replace_timeout(create_ping_timeout(
+                        &mut self.timeout_events,
+                        PING_PROBES_COUNT - 1,
                         client_id,
-                        &ErrorKind::TimedOut.into(),
-                        "No ping response",
-                    )?;
+                    ));
                 }
             }
         }
@@ -576,12 +594,6 @@
 
     pub fn client_readable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
         let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) {
-            let timeout = client.replace_timeout(create_ping_timeout(
-                &mut self.timer,
-                PING_PROBES_COUNT - 1,
-                client_id,
-            ));
-            self.timer.cancel_timeout(&timeout);
             client.read()
         } else {
             warn!("invalid readable client: {}", client_id);
@@ -657,7 +669,7 @@
     }
 
     pub fn has_pending_operations(&self) -> bool {
-        !self.pending.is_empty()
+        !self.pending.is_empty() || !self.timeout_events.is_empty()
     }
 
     pub fn on_idle(&mut self, poll: &Poll) -> io::Result<()> {
@@ -731,13 +743,17 @@
         }
     }
 
-    pub fn build(self) -> NetworkLayer {
+    pub fn build(self, poll: &Poll) -> NetworkLayer {
         let server_state = ServerState::new(self.clients_capacity, self.rooms_capacity);
 
         let clients = Slab::with_capacity(self.clients_capacity);
         let pending = HashSet::with_capacity(2 * self.clients_capacity);
         let pending_cache = Vec::with_capacity(2 * self.clients_capacity);
-        let timer = timer::Builder::default().build();
+        let timeout_events = NetworkTimeoutEvents::new();
+
+        #[cfg(feature = "official-server")]
+        let waker = Waker::new(poll.registry(), utils::IO_TOKEN)
+            .expect("Unable to create a waker for the IO thread");
 
         NetworkLayer {
             listener: self.listener.expect("No listener provided"),
@@ -750,8 +766,8 @@
                 self.secure_listener.expect("No secure listener provided"),
             ),
             #[cfg(feature = "official-server")]
-            io: IoLayer::new(),
-            timer,
+            io: IoLayer::new(waker),
+            timeout_events,
         }
     }
 }