update mio
authoralfadur
Tue, 22 Jun 2021 01:41:33 +0300 (2021-06-21)
changeset 15822 6af892a0a4b8
parent 15821 ed3b510b860c
child 15823 f57a3d48072b
update mio
rust/hedgewars-server/Cargo.toml
rust/hedgewars-server/src/core/events.rs
rust/hedgewars-server/src/handlers/common.rs
rust/hedgewars-server/src/main.rs
rust/hedgewars-server/src/server/io.rs
rust/hedgewars-server/src/server/network.rs
rust/hedgewars-server/src/utils.rs
--- a/rust/hedgewars-server/Cargo.toml	Mon Jun 21 20:11:22 2021 +0300
+++ b/rust/hedgewars-server/Cargo.toml	Tue Jun 22 01:41:33 2021 +0300
@@ -11,10 +11,9 @@
 
 [dependencies]
 getopts = "0.2"
-rand = "0.7"
+rand = "0.8"
 chrono = "0.4"
-mio = "0.6"
-mio-extras = "2.0"
+mio = { version = "0.7", features = ["os-poll", "net"] }
 slab = "0.4"
 netbuf = "0.4"
 nom = "5.1"
--- a/rust/hedgewars-server/src/core/events.rs	Mon Jun 21 20:11:22 2021 +0300
+++ b/rust/hedgewars-server/src/core/events.rs	Tue Jun 22 01:41:33 2021 +0300
@@ -23,6 +23,7 @@
     current_time: Instant,
     current_tick_index: u32,
     next_event_id: u32,
+    events_count: u32,
 }
 
 impl<Data, const MAX_TIMEOUT: usize> TimedEvents<Data, MAX_TIMEOUT> {
@@ -37,6 +38,7 @@
             current_time: Instant::now(),
             current_tick_index: 0,
             next_event_id: 0,
+            events_count: 0,
         }
     }
 
@@ -51,6 +53,9 @@
         let entry = self.events[tick_index as usize].vacant_entry();
         let event_index = entry.key() as u32;
         entry.insert(event);
+
+        self.events_count += 1;
+
         Timeout {
             tick_index,
             event_index,
@@ -62,6 +67,7 @@
         let events = &mut self.events[timeout.tick_index as usize];
         if matches!(events.get(timeout.event_index as usize), Some(Event { event_id: id, ..}) if *id == timeout.event_id)
         {
+            self.events_count -= 1;
             Some(events.remove(timeout.event_index as usize).data)
         } else {
             None
@@ -80,8 +86,13 @@
                     .map(|e| e.data),
             );
         }
+        self.events_count -= result.len() as u32;
         result
     }
+
+    pub fn is_empty(&self) -> bool {
+        self.events_count == 0
+    }
 }
 
 mod test {
--- a/rust/hedgewars-server/src/handlers/common.rs	Mon Jun 21 20:11:22 2021 +0300
+++ b/rust/hedgewars-server/src/handlers/common.rs	Tue Jun 22 01:41:33 2021 +0300
@@ -567,7 +567,7 @@
                 }
             }
             VoteType::NewSeed => {
-                let seed = thread_rng().gen_range(0, 1_000_000_000).to_string();
+                let seed = thread_rng().gen_range(0..1_000_000_000).to_string();
                 let cfg = GameCfg::Seed(seed);
                 response.add(cfg.to_server_msg().send_all().in_room(room_id));
                 room_control.set_config(cfg);
--- a/rust/hedgewars-server/src/main.rs	Mon Jun 21 20:11:22 2021 +0300
+++ b/rust/hedgewars-server/src/main.rs	Tue Jun 22 01:41:33 2021 +0300
@@ -4,7 +4,11 @@
 use getopts::Options;
 use log::*;
 use mio::{net::*, *};
-use std::{env, str::FromStr as _, time::Duration};
+use std::{
+    env,
+    str::FromStr as _,
+    time::{Duration, Instant},
+};
 
 mod core;
 mod handlers;
@@ -44,32 +48,35 @@
         .unwrap_or(46631);
     let address = format!("0.0.0.0:{}", port).parse().unwrap();
 
-    let listener = TcpListener::bind(&address).unwrap();
+    let listener = TcpListener::bind(address).unwrap();
 
-    let poll = Poll::new().unwrap();
+    let mut poll = Poll::new().unwrap();
     let mut hw_builder = NetworkLayerBuilder::default().with_listener(listener);
 
     #[cfg(feature = "tls-connections")]
     {
         let address = format!("0.0.0.0:{}", port + 1).parse().unwrap();
-        hw_builder = hw_builder.with_secure_listener(TcpListener::bind(&address).unwrap());
+        hw_builder = hw_builder.with_secure_listener(TcpListener::bind(address).unwrap());
     }
 
-    let mut hw_network = hw_builder.build();
+    let mut hw_network = hw_builder.build(&poll);
     hw_network.register(&poll).unwrap();
 
     let mut events = Events::with_capacity(1024);
 
+    let mut time = Instant::now();
+
     loop {
         let timeout = if hw_network.has_pending_operations() {
             Some(Duration::from_millis(1))
         } else {
             None
         };
+
         poll.poll(&mut events, timeout).unwrap();
 
         for event in events.iter() {
-            if event.readiness() & Ready::readable() == Ready::readable() {
+            if event.is_readable() {
                 match event.token() {
                     token @ (utils::SERVER_TOKEN | utils::SECURE_SERVER_TOKEN) => {
                         match hw_network.accept_client(&poll, token) {
@@ -77,10 +84,6 @@
                             Err(e) => debug!("Error accepting client: {}", e),
                         }
                     }
-                    utils::TIMER_TOKEN => match hw_network.handle_timeout(&poll) {
-                        Ok(()) => (),
-                        Err(e) => debug!("Error in timer event: {}", e),
-                    },
                     #[cfg(feature = "official-server")]
                     utils::IO_TOKEN => match hw_network.handle_io_result(&poll) {
                         Ok(()) => (),
@@ -92,12 +95,11 @@
                     },
                 }
             }
-            if event.readiness() & Ready::writable() == Ready::writable() {
+            if event.is_writable() {
                 match event.token() {
-                    utils::SERVER_TOKEN
-                    | utils::SECURE_SERVER_TOKEN
-                    | utils::TIMER_TOKEN
-                    | utils::IO_TOKEN => unreachable!(),
+                    utils::SERVER_TOKEN | utils::SECURE_SERVER_TOKEN | utils::IO_TOKEN => {
+                        unreachable!()
+                    }
                     Token(token) => match hw_network.client_writable(&poll, token) {
                         Ok(()) => (),
                         Err(e) => debug!("Error writing to client socket {}: {}", token, e),
@@ -110,5 +112,13 @@
             Ok(()) => (),
             Err(e) => debug!("Error in idle handler: {}", e),
         };
+
+        if time.elapsed() > Duration::from_secs(1) {
+            time = Instant::now();
+            match hw_network.handle_timeout(&mut poll) {
+                Ok(()) => (),
+                Err(e) => debug!("Error in timer event: {}", e),
+            }
+        }
     }
 }
--- a/rust/hedgewars-server/src/server/io.rs	Mon Jun 21 20:11:22 2021 +0300
+++ b/rust/hedgewars-server/src/server/io.rs	Tue Jun 22 01:41:33 2021 +0300
@@ -1,7 +1,7 @@
 use std::{
     fs::{File, OpenOptions},
     io::{Error, ErrorKind, Read, Result, Write},
-    sync::mpsc,
+    sync::{mpsc, Arc},
     thread,
 };
 
@@ -10,20 +10,19 @@
     server::database::Database,
 };
 use log::*;
-use mio::{Evented, Poll, PollOpt};
-use mio_extras::channel;
+use mio::{Poll, Waker};
 
 pub type RequestId = u32;
 
 pub struct IoThread {
     core_tx: mpsc::Sender<(RequestId, IoTask)>,
-    core_rx: channel::Receiver<(RequestId, IoResult)>,
+    core_rx: mpsc::Receiver<(RequestId, IoResult)>,
 }
 
 impl IoThread {
-    pub fn new() -> Self {
+    pub fn new(waker: Waker) -> Self {
         let (core_tx, io_rx) = mpsc::channel();
-        let (io_tx, core_rx) = channel::channel();
+        let (io_tx, core_rx) = mpsc::channel();
 
         let mut db = Database::new();
         db.connect("localhost");
@@ -138,6 +137,7 @@
                     }
                 };
                 io_tx.send((request_id, response));
+                waker.wake();
             }
         });
 
@@ -155,11 +155,6 @@
             Err(mpsc::TryRecvError::Disconnected) => unreachable!(),
         }
     }
-
-    pub fn register_rx(&self, poll: &mio::Poll, token: mio::Token) -> Result<()> {
-        self.core_rx
-            .register(poll, token, mio::Ready::readable(), PollOpt::edge())
-    }
 }
 
 fn save_file(filename: &str, contents: &str) -> Result<()> {
--- a/rust/hedgewars-server/src/server/network.rs	Mon Jun 21 20:11:22 2021 +0300
+++ b/rust/hedgewars-server/src/server/network.rs	Tue Jun 22 01:41:33 2021 +0300
@@ -6,19 +6,25 @@
     io::{Error, ErrorKind, Read, Write},
     mem::{replace, swap},
     net::{IpAddr, Ipv4Addr, SocketAddr},
+    num::NonZeroU32,
+    time::Duration,
+    time::Instant,
 };
 
 use log::*;
 use mio::{
+    event::Source,
     net::{TcpListener, TcpStream},
-    Evented, Poll, PollOpt, Ready, Token,
+    Interest, Poll, Token, Waker,
 };
-use mio_extras::timer;
 use netbuf;
 use slab::Slab;
 
 use crate::{
-    core::types::ClientId,
+    core::{
+        events::{TimedEvents, Timeout},
+        types::ClientId,
+    },
     handlers,
     handlers::{IoResult, IoTask, ServerState},
     protocol::{messages::HwServerMessage::Redirect, messages::*, ProtocolDecoder},
@@ -36,11 +42,11 @@
         SslMethod, SslOptions, SslStream, SslStreamBuilder, SslVerifyMode,
     },
 };
-use std::time::Duration;
 
 const MAX_BYTES_PER_READ: usize = 2048;
-const SEND_PING_TIMEOUT: Duration = Duration::from_secs(30);
-const DROP_CLIENT_TIMEOUT: Duration = Duration::from_secs(30);
+const SEND_PING_TIMEOUT: Duration = Duration::from_secs(5);
+const DROP_CLIENT_TIMEOUT: Duration = Duration::from_secs(5);
+const MAX_TIMEOUT: usize = DROP_CLIENT_TIMEOUT.as_secs() as usize;
 const PING_PROBES_COUNT: u8 = 2;
 
 #[derive(Hash, Eq, PartialEq, Copy, Clone)]
@@ -64,15 +70,15 @@
 }
 
 impl ClientSocket {
-    fn inner(&self) -> &TcpStream {
+    fn inner_mut(&mut self) -> &mut TcpStream {
         match self {
             ClientSocket::Plain(stream) => stream,
             #[cfg(feature = "tls-connections")]
-            ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(),
+            ClientSocket::SslHandshake(Some(builder)) => builder.get_mut(),
             #[cfg(feature = "tls-connections")]
             ClientSocket::SslHandshake(None) => unreachable!(),
             #[cfg(feature = "tls-connections")]
-            ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref(),
+            ClientSocket::SslStream(ssl_stream) => ssl_stream.get_mut(),
         }
     }
 }
@@ -83,8 +89,9 @@
     peer_addr: SocketAddr,
     decoder: ProtocolDecoder,
     buf_out: netbuf::Buf,
-    timeout: timer::Timeout,
     pending_close: bool,
+    timeout: Timeout,
+    last_rx_time: Instant,
 }
 
 impl NetworkClient {
@@ -92,7 +99,7 @@
         id: ClientId,
         socket: ClientSocket,
         peer_addr: SocketAddr,
-        timeout: timer::Timeout,
+        timeout: Timeout,
     ) -> NetworkClient {
         NetworkClient {
             id,
@@ -100,8 +107,9 @@
             peer_addr,
             decoder: ProtocolDecoder::new(),
             buf_out: netbuf::Buf::new(),
+            pending_close: false,
             timeout,
-            pending_close: false,
+            last_rx_time: Instant::now(),
         }
     }
 
@@ -171,7 +179,7 @@
     }
 
     pub fn read(&mut self) -> NetworkResult<Vec<HwProtocolMessage>> {
-        match self.socket {
+        let result = match self.socket {
             ClientSocket::Plain(ref mut stream) => {
                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
             }
@@ -184,7 +192,13 @@
             ClientSocket::SslStream(ref mut stream) => {
                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
             }
+        };
+
+        if let Ok(_) = result {
+            self.last_rx_time = Instant::now();
         }
+
+        result
     }
 
     fn write_impl<W: Write>(
@@ -231,7 +245,7 @@
             }
         };
 
-        self.socket.inner().flush()?;
+        self.socket.inner_mut().flush()?;
         result
     }
 
@@ -243,7 +257,7 @@
         self.send_raw_msg(&msg.as_bytes());
     }
 
-    pub fn replace_timeout(&mut self, timeout: timer::Timeout) -> timer::Timeout {
+    pub fn replace_timeout(&mut self, timeout: Timeout) -> Timeout {
         replace(&mut self.timeout, timeout)
     }
 
@@ -267,11 +281,11 @@
 
 #[cfg(feature = "official-server")]
 impl IoLayer {
-    fn new() -> Self {
+    fn new(waker: Waker) -> Self {
         Self {
             next_request_id: 0,
             request_queue: vec![],
-            io_thread: IoThread::new(),
+            io_thread: IoThread::new(waker),
         }
     }
 
@@ -314,6 +328,7 @@
 }
 
 struct TimerData(TimeoutEvent, ClientId);
+type NetworkTimeoutEvents = TimedEvents<TimerData, MAX_TIMEOUT>;
 
 pub struct NetworkLayer {
     listener: TcpListener,
@@ -325,47 +340,44 @@
     ssl: ServerSsl,
     #[cfg(feature = "official-server")]
     io: IoLayer,
-    timer: timer::Timer<TimerData>,
+    timeout_events: NetworkTimeoutEvents,
 }
 
-fn register_read<E: Evented>(poll: &Poll, evented: &E, token: mio::Token) -> io::Result<()> {
-    poll.register(evented, token, Ready::readable(), PollOpt::edge())
+fn register_read<S: Source>(poll: &Poll, source: &mut S, token: mio::Token) -> io::Result<()> {
+    poll.registry().register(source, token, Interest::READABLE)
 }
 
 fn create_ping_timeout(
-    timer: &mut timer::Timer<TimerData>,
+    timeout_events: &mut NetworkTimeoutEvents,
     probes_count: u8,
     client_id: ClientId,
-) -> timer::Timeout {
-    timer.set_timeout(
-        SEND_PING_TIMEOUT,
+) -> Timeout {
+    timeout_events.set_timeout(
+        NonZeroU32::new(SEND_PING_TIMEOUT.as_secs() as u32).unwrap(),
         TimerData(TimeoutEvent::SendPing { probes_count }, client_id),
     )
 }
 
-fn create_drop_timeout(timer: &mut timer::Timer<TimerData>, client_id: ClientId) -> timer::Timeout {
-    timer.set_timeout(
-        DROP_CLIENT_TIMEOUT,
+fn create_drop_timeout(timeout_events: &mut NetworkTimeoutEvents, client_id: ClientId) -> Timeout {
+    timeout_events.set_timeout(
+        NonZeroU32::new(DROP_CLIENT_TIMEOUT.as_secs() as u32).unwrap(),
         TimerData(TimeoutEvent::DropClient, client_id),
     )
 }
 
 impl NetworkLayer {
-    pub fn register(&self, poll: &Poll) -> io::Result<()> {
-        register_read(poll, &self.listener, utils::SERVER_TOKEN)?;
+    pub fn register(&mut self, poll: &Poll) -> io::Result<()> {
+        register_read(poll, &mut self.listener, utils::SERVER_TOKEN)?;
         #[cfg(feature = "tls-connections")]
-        register_read(poll, &self.ssl.listener, utils::SECURE_SERVER_TOKEN)?;
-        register_read(poll, &self.timer, utils::TIMER_TOKEN)?;
-
-        #[cfg(feature = "official-server")]
-        self.io.io_thread.register_rx(poll, utils::IO_TOKEN)?;
+        register_read(poll, &mut self.ssl.listener, utils::SECURE_SERVER_TOKEN)?;
 
         Ok(())
     }
 
     fn deregister_client(&mut self, poll: &Poll, id: ClientId, is_error: bool) {
         if let Some(ref mut client) = self.clients.get_mut(id) {
-            poll.deregister(client.socket.inner())
+            poll.registry()
+                .deregister(client.socket.inner_mut())
                 .expect("could not deregister socket");
             if client.has_pending_sends() && !is_error {
                 info!(
@@ -373,15 +385,11 @@
                     client.id, client.peer_addr
                 );
                 client.pending_close = true;
-                poll.register(
-                    client.socket.inner(),
-                    Token(id),
-                    Ready::writable(),
-                    PollOpt::edge(),
-                )
-                .unwrap_or_else(|_| {
-                    self.clients.remove(id);
-                });
+                poll.registry()
+                    .register(client.socket.inner_mut(), Token(id), Interest::WRITABLE)
+                    .unwrap_or_else(|_| {
+                        self.clients.remove(id);
+                    });
             } else {
                 info!("client {} ({}) removed", client.id, client.peer_addr);
                 self.clients.remove(id);
@@ -394,24 +402,23 @@
     fn register_client(
         &mut self,
         poll: &Poll,
-        client_socket: ClientSocket,
+        mut client_socket: ClientSocket,
         addr: SocketAddr,
     ) -> io::Result<ClientId> {
         let entry = self.clients.vacant_entry();
         let client_id = entry.key();
 
-        poll.register(
-            client_socket.inner(),
+        poll.registry().register(
+            client_socket.inner_mut(),
             Token(client_id),
-            Ready::readable() | Ready::writable(),
-            PollOpt::edge(),
+            Interest::READABLE | Interest::WRITABLE,
         )?;
 
         let client = NetworkClient::new(
             client_id,
             client_socket,
             addr,
-            create_ping_timeout(&mut self.timer, PING_PROBES_COUNT - 1, client_id),
+            create_ping_timeout(&mut self.timeout_events, PING_PROBES_COUNT - 1, client_id),
         );
         info!("client {} ({}) added", client.id, client.peer_addr);
         entry.insert(client);
@@ -451,34 +458,45 @@
         }
     }
 
-    pub fn handle_timeout(&mut self, poll: &Poll) -> io::Result<()> {
-        while let Some(TimerData(event, client_id)) = self.timer.poll() {
-            match event {
-                TimeoutEvent::SendPing { probes_count } => {
-                    if let Some(ref mut client) = self.clients.get_mut(client_id) {
-                        client.send_string(&HwServerMessage::Ping.to_raw_protocol());
-                        client.write()?;
-                        let timeout = if probes_count != 0 {
-                            create_ping_timeout(&mut self.timer, probes_count - 1, client_id)
-                        } else {
-                            create_drop_timeout(&mut self.timer, client_id)
-                        };
-                        client.replace_timeout(timeout);
+    pub fn handle_timeout(&mut self, poll: &mut Poll) -> io::Result<()> {
+        for TimerData(event, client_id) in self.timeout_events.poll(Instant::now()) {
+            if let Some(client) = self.clients.get_mut(client_id) {
+                if client.last_rx_time.elapsed() > SEND_PING_TIMEOUT {
+                    match event {
+                        TimeoutEvent::SendPing { probes_count } => {
+                            client.send_string(&HwServerMessage::Ping.to_raw_protocol());
+                            client.write()?;
+                            let timeout = if probes_count != 0 {
+                                create_ping_timeout(
+                                    &mut self.timeout_events,
+                                    probes_count - 1,
+                                    client_id,
+                                )
+                            } else {
+                                create_drop_timeout(&mut self.timeout_events, client_id)
+                            };
+                            client.replace_timeout(timeout);
+                        }
+                        TimeoutEvent::DropClient => {
+                            client.send_string(
+                                &HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol(),
+                            );
+                            let _res = client.write();
+
+                            self.operation_failed(
+                                poll,
+                                client_id,
+                                &ErrorKind::TimedOut.into(),
+                                "No ping response",
+                            )?;
+                        }
                     }
-                }
-                TimeoutEvent::DropClient => {
-                    if let Some(ref mut client) = self.clients.get_mut(client_id) {
-                        client.send_string(
-                            &HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol(),
-                        );
-                        let _res = client.write();
-                    }
-                    self.operation_failed(
-                        poll,
+                } else {
+                    client.replace_timeout(create_ping_timeout(
+                        &mut self.timeout_events,
+                        PING_PROBES_COUNT - 1,
                         client_id,
-                        &ErrorKind::TimedOut.into(),
-                        "No ping response",
-                    )?;
+                    ));
                 }
             }
         }
@@ -576,12 +594,6 @@
 
     pub fn client_readable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
         let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) {
-            let timeout = client.replace_timeout(create_ping_timeout(
-                &mut self.timer,
-                PING_PROBES_COUNT - 1,
-                client_id,
-            ));
-            self.timer.cancel_timeout(&timeout);
             client.read()
         } else {
             warn!("invalid readable client: {}", client_id);
@@ -657,7 +669,7 @@
     }
 
     pub fn has_pending_operations(&self) -> bool {
-        !self.pending.is_empty()
+        !self.pending.is_empty() || !self.timeout_events.is_empty()
     }
 
     pub fn on_idle(&mut self, poll: &Poll) -> io::Result<()> {
@@ -731,13 +743,17 @@
         }
     }
 
-    pub fn build(self) -> NetworkLayer {
+    pub fn build(self, poll: &Poll) -> NetworkLayer {
         let server_state = ServerState::new(self.clients_capacity, self.rooms_capacity);
 
         let clients = Slab::with_capacity(self.clients_capacity);
         let pending = HashSet::with_capacity(2 * self.clients_capacity);
         let pending_cache = Vec::with_capacity(2 * self.clients_capacity);
-        let timer = timer::Builder::default().build();
+        let timeout_events = NetworkTimeoutEvents::new();
+
+        #[cfg(feature = "official-server")]
+        let waker = Waker::new(poll.registry(), utils::IO_TOKEN)
+            .expect("Unable to create a waker for the IO thread");
 
         NetworkLayer {
             listener: self.listener.expect("No listener provided"),
@@ -750,8 +766,8 @@
                 self.secure_listener.expect("No secure listener provided"),
             ),
             #[cfg(feature = "official-server")]
-            io: IoLayer::new(),
-            timer,
+            io: IoLayer::new(waker),
+            timeout_events,
         }
     }
 }
--- a/rust/hedgewars-server/src/utils.rs	Mon Jun 21 20:11:22 2021 +0300
+++ b/rust/hedgewars-server/src/utils.rs	Tue Jun 22 01:41:33 2021 +0300
@@ -5,7 +5,6 @@
 pub const SERVER_VERSION: u32 = 3;
 pub const SERVER_TOKEN: mio::Token = mio::Token(1_000_000_000);
 pub const SECURE_SERVER_TOKEN: mio::Token = mio::Token(1_000_000_001);
-pub const TIMER_TOKEN: mio::Token = mio::Token(1_000_000_002);
 pub const IO_TOKEN: mio::Token = mio::Token(1_000_000_003);
 
 pub fn is_name_illegal(name: &str) -> bool {