diff -r ed3b510b860c -r 6af892a0a4b8 rust/hedgewars-server/src/server/network.rs --- a/rust/hedgewars-server/src/server/network.rs Mon Jun 21 20:11:22 2021 +0300 +++ b/rust/hedgewars-server/src/server/network.rs Tue Jun 22 01:41:33 2021 +0300 @@ -6,19 +6,25 @@ io::{Error, ErrorKind, Read, Write}, mem::{replace, swap}, net::{IpAddr, Ipv4Addr, SocketAddr}, + num::NonZeroU32, + time::Duration, + time::Instant, }; use log::*; use mio::{ + event::Source, net::{TcpListener, TcpStream}, - Evented, Poll, PollOpt, Ready, Token, + Interest, Poll, Token, Waker, }; -use mio_extras::timer; use netbuf; use slab::Slab; use crate::{ - core::types::ClientId, + core::{ + events::{TimedEvents, Timeout}, + types::ClientId, + }, handlers, handlers::{IoResult, IoTask, ServerState}, protocol::{messages::HwServerMessage::Redirect, messages::*, ProtocolDecoder}, @@ -36,11 +42,11 @@ SslMethod, SslOptions, SslStream, SslStreamBuilder, SslVerifyMode, }, }; -use std::time::Duration; const MAX_BYTES_PER_READ: usize = 2048; -const SEND_PING_TIMEOUT: Duration = Duration::from_secs(30); -const DROP_CLIENT_TIMEOUT: Duration = Duration::from_secs(30); +const SEND_PING_TIMEOUT: Duration = Duration::from_secs(5); +const DROP_CLIENT_TIMEOUT: Duration = Duration::from_secs(5); +const MAX_TIMEOUT: usize = DROP_CLIENT_TIMEOUT.as_secs() as usize; const PING_PROBES_COUNT: u8 = 2; #[derive(Hash, Eq, PartialEq, Copy, Clone)] @@ -64,15 +70,15 @@ } impl ClientSocket { - fn inner(&self) -> &TcpStream { + fn inner_mut(&mut self) -> &mut TcpStream { match self { ClientSocket::Plain(stream) => stream, #[cfg(feature = "tls-connections")] - ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(), + ClientSocket::SslHandshake(Some(builder)) => builder.get_mut(), #[cfg(feature = "tls-connections")] ClientSocket::SslHandshake(None) => unreachable!(), #[cfg(feature = "tls-connections")] - ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref(), + ClientSocket::SslStream(ssl_stream) => ssl_stream.get_mut(), } } } @@ -83,8 +89,9 @@ peer_addr: SocketAddr, decoder: ProtocolDecoder, buf_out: netbuf::Buf, - timeout: timer::Timeout, pending_close: bool, + timeout: Timeout, + last_rx_time: Instant, } impl NetworkClient { @@ -92,7 +99,7 @@ id: ClientId, socket: ClientSocket, peer_addr: SocketAddr, - timeout: timer::Timeout, + timeout: Timeout, ) -> NetworkClient { NetworkClient { id, @@ -100,8 +107,9 @@ peer_addr, decoder: ProtocolDecoder::new(), buf_out: netbuf::Buf::new(), + pending_close: false, timeout, - pending_close: false, + last_rx_time: Instant::now(), } } @@ -171,7 +179,7 @@ } pub fn read(&mut self) -> NetworkResult> { - match self.socket { + let result = match self.socket { ClientSocket::Plain(ref mut stream) => { NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) } @@ -184,7 +192,13 @@ ClientSocket::SslStream(ref mut stream) => { NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) } + }; + + if let Ok(_) = result { + self.last_rx_time = Instant::now(); } + + result } fn write_impl( @@ -231,7 +245,7 @@ } }; - self.socket.inner().flush()?; + self.socket.inner_mut().flush()?; result } @@ -243,7 +257,7 @@ self.send_raw_msg(&msg.as_bytes()); } - pub fn replace_timeout(&mut self, timeout: timer::Timeout) -> timer::Timeout { + pub fn replace_timeout(&mut self, timeout: Timeout) -> Timeout { replace(&mut self.timeout, timeout) } @@ -267,11 +281,11 @@ #[cfg(feature = "official-server")] impl IoLayer { - fn new() -> Self { + fn new(waker: Waker) -> Self { Self { next_request_id: 0, request_queue: vec![], - io_thread: IoThread::new(), + io_thread: IoThread::new(waker), } } @@ -314,6 +328,7 @@ } struct TimerData(TimeoutEvent, ClientId); +type NetworkTimeoutEvents = TimedEvents; pub struct NetworkLayer { listener: TcpListener, @@ -325,47 +340,44 @@ ssl: ServerSsl, #[cfg(feature = "official-server")] io: IoLayer, - timer: timer::Timer, + timeout_events: NetworkTimeoutEvents, } -fn register_read(poll: &Poll, evented: &E, token: mio::Token) -> io::Result<()> { - poll.register(evented, token, Ready::readable(), PollOpt::edge()) +fn register_read(poll: &Poll, source: &mut S, token: mio::Token) -> io::Result<()> { + poll.registry().register(source, token, Interest::READABLE) } fn create_ping_timeout( - timer: &mut timer::Timer, + timeout_events: &mut NetworkTimeoutEvents, probes_count: u8, client_id: ClientId, -) -> timer::Timeout { - timer.set_timeout( - SEND_PING_TIMEOUT, +) -> Timeout { + timeout_events.set_timeout( + NonZeroU32::new(SEND_PING_TIMEOUT.as_secs() as u32).unwrap(), TimerData(TimeoutEvent::SendPing { probes_count }, client_id), ) } -fn create_drop_timeout(timer: &mut timer::Timer, client_id: ClientId) -> timer::Timeout { - timer.set_timeout( - DROP_CLIENT_TIMEOUT, +fn create_drop_timeout(timeout_events: &mut NetworkTimeoutEvents, client_id: ClientId) -> Timeout { + timeout_events.set_timeout( + NonZeroU32::new(DROP_CLIENT_TIMEOUT.as_secs() as u32).unwrap(), TimerData(TimeoutEvent::DropClient, client_id), ) } impl NetworkLayer { - pub fn register(&self, poll: &Poll) -> io::Result<()> { - register_read(poll, &self.listener, utils::SERVER_TOKEN)?; + pub fn register(&mut self, poll: &Poll) -> io::Result<()> { + register_read(poll, &mut self.listener, utils::SERVER_TOKEN)?; #[cfg(feature = "tls-connections")] - register_read(poll, &self.ssl.listener, utils::SECURE_SERVER_TOKEN)?; - register_read(poll, &self.timer, utils::TIMER_TOKEN)?; - - #[cfg(feature = "official-server")] - self.io.io_thread.register_rx(poll, utils::IO_TOKEN)?; + register_read(poll, &mut self.ssl.listener, utils::SECURE_SERVER_TOKEN)?; Ok(()) } fn deregister_client(&mut self, poll: &Poll, id: ClientId, is_error: bool) { if let Some(ref mut client) = self.clients.get_mut(id) { - poll.deregister(client.socket.inner()) + poll.registry() + .deregister(client.socket.inner_mut()) .expect("could not deregister socket"); if client.has_pending_sends() && !is_error { info!( @@ -373,15 +385,11 @@ client.id, client.peer_addr ); client.pending_close = true; - poll.register( - client.socket.inner(), - Token(id), - Ready::writable(), - PollOpt::edge(), - ) - .unwrap_or_else(|_| { - self.clients.remove(id); - }); + poll.registry() + .register(client.socket.inner_mut(), Token(id), Interest::WRITABLE) + .unwrap_or_else(|_| { + self.clients.remove(id); + }); } else { info!("client {} ({}) removed", client.id, client.peer_addr); self.clients.remove(id); @@ -394,24 +402,23 @@ fn register_client( &mut self, poll: &Poll, - client_socket: ClientSocket, + mut client_socket: ClientSocket, addr: SocketAddr, ) -> io::Result { let entry = self.clients.vacant_entry(); let client_id = entry.key(); - poll.register( - client_socket.inner(), + poll.registry().register( + client_socket.inner_mut(), Token(client_id), - Ready::readable() | Ready::writable(), - PollOpt::edge(), + Interest::READABLE | Interest::WRITABLE, )?; let client = NetworkClient::new( client_id, client_socket, addr, - create_ping_timeout(&mut self.timer, PING_PROBES_COUNT - 1, client_id), + create_ping_timeout(&mut self.timeout_events, PING_PROBES_COUNT - 1, client_id), ); info!("client {} ({}) added", client.id, client.peer_addr); entry.insert(client); @@ -451,34 +458,45 @@ } } - pub fn handle_timeout(&mut self, poll: &Poll) -> io::Result<()> { - while let Some(TimerData(event, client_id)) = self.timer.poll() { - match event { - TimeoutEvent::SendPing { probes_count } => { - if let Some(ref mut client) = self.clients.get_mut(client_id) { - client.send_string(&HwServerMessage::Ping.to_raw_protocol()); - client.write()?; - let timeout = if probes_count != 0 { - create_ping_timeout(&mut self.timer, probes_count - 1, client_id) - } else { - create_drop_timeout(&mut self.timer, client_id) - }; - client.replace_timeout(timeout); + pub fn handle_timeout(&mut self, poll: &mut Poll) -> io::Result<()> { + for TimerData(event, client_id) in self.timeout_events.poll(Instant::now()) { + if let Some(client) = self.clients.get_mut(client_id) { + if client.last_rx_time.elapsed() > SEND_PING_TIMEOUT { + match event { + TimeoutEvent::SendPing { probes_count } => { + client.send_string(&HwServerMessage::Ping.to_raw_protocol()); + client.write()?; + let timeout = if probes_count != 0 { + create_ping_timeout( + &mut self.timeout_events, + probes_count - 1, + client_id, + ) + } else { + create_drop_timeout(&mut self.timeout_events, client_id) + }; + client.replace_timeout(timeout); + } + TimeoutEvent::DropClient => { + client.send_string( + &HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol(), + ); + let _res = client.write(); + + self.operation_failed( + poll, + client_id, + &ErrorKind::TimedOut.into(), + "No ping response", + )?; + } } - } - TimeoutEvent::DropClient => { - if let Some(ref mut client) = self.clients.get_mut(client_id) { - client.send_string( - &HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol(), - ); - let _res = client.write(); - } - self.operation_failed( - poll, + } else { + client.replace_timeout(create_ping_timeout( + &mut self.timeout_events, + PING_PROBES_COUNT - 1, client_id, - &ErrorKind::TimedOut.into(), - "No ping response", - )?; + )); } } } @@ -576,12 +594,6 @@ pub fn client_readable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> { let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) { - let timeout = client.replace_timeout(create_ping_timeout( - &mut self.timer, - PING_PROBES_COUNT - 1, - client_id, - )); - self.timer.cancel_timeout(&timeout); client.read() } else { warn!("invalid readable client: {}", client_id); @@ -657,7 +669,7 @@ } pub fn has_pending_operations(&self) -> bool { - !self.pending.is_empty() + !self.pending.is_empty() || !self.timeout_events.is_empty() } pub fn on_idle(&mut self, poll: &Poll) -> io::Result<()> { @@ -731,13 +743,17 @@ } } - pub fn build(self) -> NetworkLayer { + pub fn build(self, poll: &Poll) -> NetworkLayer { let server_state = ServerState::new(self.clients_capacity, self.rooms_capacity); let clients = Slab::with_capacity(self.clients_capacity); let pending = HashSet::with_capacity(2 * self.clients_capacity); let pending_cache = Vec::with_capacity(2 * self.clients_capacity); - let timer = timer::Builder::default().build(); + let timeout_events = NetworkTimeoutEvents::new(); + + #[cfg(feature = "official-server")] + let waker = Waker::new(poll.registry(), utils::IO_TOKEN) + .expect("Unable to create a waker for the IO thread"); NetworkLayer { listener: self.listener.expect("No listener provided"), @@ -750,8 +766,8 @@ self.secure_listener.expect("No secure listener provided"), ), #[cfg(feature = "official-server")] - io: IoLayer::new(), - timer, + io: IoLayer::new(waker), + timeout_events, } } }