--- a/rust/hedgewars-server/src/server/network.rs Mon Jun 21 20:11:22 2021 +0300
+++ b/rust/hedgewars-server/src/server/network.rs Tue Jun 22 01:41:33 2021 +0300
@@ -6,19 +6,25 @@
io::{Error, ErrorKind, Read, Write},
mem::{replace, swap},
net::{IpAddr, Ipv4Addr, SocketAddr},
+ num::NonZeroU32,
+ time::Duration,
+ time::Instant,
};
use log::*;
use mio::{
+ event::Source,
net::{TcpListener, TcpStream},
- Evented, Poll, PollOpt, Ready, Token,
+ Interest, Poll, Token, Waker,
};
-use mio_extras::timer;
use netbuf;
use slab::Slab;
use crate::{
- core::types::ClientId,
+ core::{
+ events::{TimedEvents, Timeout},
+ types::ClientId,
+ },
handlers,
handlers::{IoResult, IoTask, ServerState},
protocol::{messages::HwServerMessage::Redirect, messages::*, ProtocolDecoder},
@@ -36,11 +42,11 @@
SslMethod, SslOptions, SslStream, SslStreamBuilder, SslVerifyMode,
},
};
-use std::time::Duration;
const MAX_BYTES_PER_READ: usize = 2048;
-const SEND_PING_TIMEOUT: Duration = Duration::from_secs(30);
-const DROP_CLIENT_TIMEOUT: Duration = Duration::from_secs(30);
+const SEND_PING_TIMEOUT: Duration = Duration::from_secs(5);
+const DROP_CLIENT_TIMEOUT: Duration = Duration::from_secs(5);
+const MAX_TIMEOUT: usize = DROP_CLIENT_TIMEOUT.as_secs() as usize;
const PING_PROBES_COUNT: u8 = 2;
#[derive(Hash, Eq, PartialEq, Copy, Clone)]
@@ -64,15 +70,15 @@
}
impl ClientSocket {
- fn inner(&self) -> &TcpStream {
+ fn inner_mut(&mut self) -> &mut TcpStream {
match self {
ClientSocket::Plain(stream) => stream,
#[cfg(feature = "tls-connections")]
- ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(),
+ ClientSocket::SslHandshake(Some(builder)) => builder.get_mut(),
#[cfg(feature = "tls-connections")]
ClientSocket::SslHandshake(None) => unreachable!(),
#[cfg(feature = "tls-connections")]
- ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref(),
+ ClientSocket::SslStream(ssl_stream) => ssl_stream.get_mut(),
}
}
}
@@ -83,8 +89,9 @@
peer_addr: SocketAddr,
decoder: ProtocolDecoder,
buf_out: netbuf::Buf,
- timeout: timer::Timeout,
pending_close: bool,
+ timeout: Timeout,
+ last_rx_time: Instant,
}
impl NetworkClient {
@@ -92,7 +99,7 @@
id: ClientId,
socket: ClientSocket,
peer_addr: SocketAddr,
- timeout: timer::Timeout,
+ timeout: Timeout,
) -> NetworkClient {
NetworkClient {
id,
@@ -100,8 +107,9 @@
peer_addr,
decoder: ProtocolDecoder::new(),
buf_out: netbuf::Buf::new(),
+ pending_close: false,
timeout,
- pending_close: false,
+ last_rx_time: Instant::now(),
}
}
@@ -171,7 +179,7 @@
}
pub fn read(&mut self) -> NetworkResult<Vec<HwProtocolMessage>> {
- match self.socket {
+ let result = match self.socket {
ClientSocket::Plain(ref mut stream) => {
NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
}
@@ -184,7 +192,13 @@
ClientSocket::SslStream(ref mut stream) => {
NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
}
+ };
+
+ if let Ok(_) = result {
+ self.last_rx_time = Instant::now();
}
+
+ result
}
fn write_impl<W: Write>(
@@ -231,7 +245,7 @@
}
};
- self.socket.inner().flush()?;
+ self.socket.inner_mut().flush()?;
result
}
@@ -243,7 +257,7 @@
self.send_raw_msg(&msg.as_bytes());
}
- pub fn replace_timeout(&mut self, timeout: timer::Timeout) -> timer::Timeout {
+ pub fn replace_timeout(&mut self, timeout: Timeout) -> Timeout {
replace(&mut self.timeout, timeout)
}
@@ -267,11 +281,11 @@
#[cfg(feature = "official-server")]
impl IoLayer {
- fn new() -> Self {
+ fn new(waker: Waker) -> Self {
Self {
next_request_id: 0,
request_queue: vec![],
- io_thread: IoThread::new(),
+ io_thread: IoThread::new(waker),
}
}
@@ -314,6 +328,7 @@
}
struct TimerData(TimeoutEvent, ClientId);
+type NetworkTimeoutEvents = TimedEvents<TimerData, MAX_TIMEOUT>;
pub struct NetworkLayer {
listener: TcpListener,
@@ -325,47 +340,44 @@
ssl: ServerSsl,
#[cfg(feature = "official-server")]
io: IoLayer,
- timer: timer::Timer<TimerData>,
+ timeout_events: NetworkTimeoutEvents,
}
-fn register_read<E: Evented>(poll: &Poll, evented: &E, token: mio::Token) -> io::Result<()> {
- poll.register(evented, token, Ready::readable(), PollOpt::edge())
+fn register_read<S: Source>(poll: &Poll, source: &mut S, token: mio::Token) -> io::Result<()> {
+ poll.registry().register(source, token, Interest::READABLE)
}
fn create_ping_timeout(
- timer: &mut timer::Timer<TimerData>,
+ timeout_events: &mut NetworkTimeoutEvents,
probes_count: u8,
client_id: ClientId,
-) -> timer::Timeout {
- timer.set_timeout(
- SEND_PING_TIMEOUT,
+) -> Timeout {
+ timeout_events.set_timeout(
+ NonZeroU32::new(SEND_PING_TIMEOUT.as_secs() as u32).unwrap(),
TimerData(TimeoutEvent::SendPing { probes_count }, client_id),
)
}
-fn create_drop_timeout(timer: &mut timer::Timer<TimerData>, client_id: ClientId) -> timer::Timeout {
- timer.set_timeout(
- DROP_CLIENT_TIMEOUT,
+fn create_drop_timeout(timeout_events: &mut NetworkTimeoutEvents, client_id: ClientId) -> Timeout {
+ timeout_events.set_timeout(
+ NonZeroU32::new(DROP_CLIENT_TIMEOUT.as_secs() as u32).unwrap(),
TimerData(TimeoutEvent::DropClient, client_id),
)
}
impl NetworkLayer {
- pub fn register(&self, poll: &Poll) -> io::Result<()> {
- register_read(poll, &self.listener, utils::SERVER_TOKEN)?;
+ pub fn register(&mut self, poll: &Poll) -> io::Result<()> {
+ register_read(poll, &mut self.listener, utils::SERVER_TOKEN)?;
#[cfg(feature = "tls-connections")]
- register_read(poll, &self.ssl.listener, utils::SECURE_SERVER_TOKEN)?;
- register_read(poll, &self.timer, utils::TIMER_TOKEN)?;
-
- #[cfg(feature = "official-server")]
- self.io.io_thread.register_rx(poll, utils::IO_TOKEN)?;
+ register_read(poll, &mut self.ssl.listener, utils::SECURE_SERVER_TOKEN)?;
Ok(())
}
fn deregister_client(&mut self, poll: &Poll, id: ClientId, is_error: bool) {
if let Some(ref mut client) = self.clients.get_mut(id) {
- poll.deregister(client.socket.inner())
+ poll.registry()
+ .deregister(client.socket.inner_mut())
.expect("could not deregister socket");
if client.has_pending_sends() && !is_error {
info!(
@@ -373,15 +385,11 @@
client.id, client.peer_addr
);
client.pending_close = true;
- poll.register(
- client.socket.inner(),
- Token(id),
- Ready::writable(),
- PollOpt::edge(),
- )
- .unwrap_or_else(|_| {
- self.clients.remove(id);
- });
+ poll.registry()
+ .register(client.socket.inner_mut(), Token(id), Interest::WRITABLE)
+ .unwrap_or_else(|_| {
+ self.clients.remove(id);
+ });
} else {
info!("client {} ({}) removed", client.id, client.peer_addr);
self.clients.remove(id);
@@ -394,24 +402,23 @@
fn register_client(
&mut self,
poll: &Poll,
- client_socket: ClientSocket,
+ mut client_socket: ClientSocket,
addr: SocketAddr,
) -> io::Result<ClientId> {
let entry = self.clients.vacant_entry();
let client_id = entry.key();
- poll.register(
- client_socket.inner(),
+ poll.registry().register(
+ client_socket.inner_mut(),
Token(client_id),
- Ready::readable() | Ready::writable(),
- PollOpt::edge(),
+ Interest::READABLE | Interest::WRITABLE,
)?;
let client = NetworkClient::new(
client_id,
client_socket,
addr,
- create_ping_timeout(&mut self.timer, PING_PROBES_COUNT - 1, client_id),
+ create_ping_timeout(&mut self.timeout_events, PING_PROBES_COUNT - 1, client_id),
);
info!("client {} ({}) added", client.id, client.peer_addr);
entry.insert(client);
@@ -451,34 +458,45 @@
}
}
- pub fn handle_timeout(&mut self, poll: &Poll) -> io::Result<()> {
- while let Some(TimerData(event, client_id)) = self.timer.poll() {
- match event {
- TimeoutEvent::SendPing { probes_count } => {
- if let Some(ref mut client) = self.clients.get_mut(client_id) {
- client.send_string(&HwServerMessage::Ping.to_raw_protocol());
- client.write()?;
- let timeout = if probes_count != 0 {
- create_ping_timeout(&mut self.timer, probes_count - 1, client_id)
- } else {
- create_drop_timeout(&mut self.timer, client_id)
- };
- client.replace_timeout(timeout);
+ pub fn handle_timeout(&mut self, poll: &mut Poll) -> io::Result<()> {
+ for TimerData(event, client_id) in self.timeout_events.poll(Instant::now()) {
+ if let Some(client) = self.clients.get_mut(client_id) {
+ if client.last_rx_time.elapsed() > SEND_PING_TIMEOUT {
+ match event {
+ TimeoutEvent::SendPing { probes_count } => {
+ client.send_string(&HwServerMessage::Ping.to_raw_protocol());
+ client.write()?;
+ let timeout = if probes_count != 0 {
+ create_ping_timeout(
+ &mut self.timeout_events,
+ probes_count - 1,
+ client_id,
+ )
+ } else {
+ create_drop_timeout(&mut self.timeout_events, client_id)
+ };
+ client.replace_timeout(timeout);
+ }
+ TimeoutEvent::DropClient => {
+ client.send_string(
+ &HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol(),
+ );
+ let _res = client.write();
+
+ self.operation_failed(
+ poll,
+ client_id,
+ &ErrorKind::TimedOut.into(),
+ "No ping response",
+ )?;
+ }
}
- }
- TimeoutEvent::DropClient => {
- if let Some(ref mut client) = self.clients.get_mut(client_id) {
- client.send_string(
- &HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol(),
- );
- let _res = client.write();
- }
- self.operation_failed(
- poll,
+ } else {
+ client.replace_timeout(create_ping_timeout(
+ &mut self.timeout_events,
+ PING_PROBES_COUNT - 1,
client_id,
- &ErrorKind::TimedOut.into(),
- "No ping response",
- )?;
+ ));
}
}
}
@@ -576,12 +594,6 @@
pub fn client_readable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) {
- let timeout = client.replace_timeout(create_ping_timeout(
- &mut self.timer,
- PING_PROBES_COUNT - 1,
- client_id,
- ));
- self.timer.cancel_timeout(&timeout);
client.read()
} else {
warn!("invalid readable client: {}", client_id);
@@ -657,7 +669,7 @@
}
pub fn has_pending_operations(&self) -> bool {
- !self.pending.is_empty()
+ !self.pending.is_empty() || !self.timeout_events.is_empty()
}
pub fn on_idle(&mut self, poll: &Poll) -> io::Result<()> {
@@ -731,13 +743,17 @@
}
}
- pub fn build(self) -> NetworkLayer {
+ pub fn build(self, poll: &Poll) -> NetworkLayer {
let server_state = ServerState::new(self.clients_capacity, self.rooms_capacity);
let clients = Slab::with_capacity(self.clients_capacity);
let pending = HashSet::with_capacity(2 * self.clients_capacity);
let pending_cache = Vec::with_capacity(2 * self.clients_capacity);
- let timer = timer::Builder::default().build();
+ let timeout_events = NetworkTimeoutEvents::new();
+
+ #[cfg(feature = "official-server")]
+ let waker = Waker::new(poll.registry(), utils::IO_TOKEN)
+ .expect("Unable to create a waker for the IO thread");
NetworkLayer {
listener: self.listener.expect("No listener provided"),
@@ -750,8 +766,8 @@
self.secure_listener.expect("No secure listener provided"),
),
#[cfg(feature = "official-server")]
- io: IoLayer::new(),
- timer,
+ io: IoLayer::new(waker),
+ timeout_events,
}
}
}