rust/hedgewars-server/src/server/network.rs
changeset 15831 7d0f747afcb8
parent 15804 747278149393
child 15832 a4d505a32879
equal deleted inserted replaced
15830:ea459da15b30 15831:7d0f747afcb8
     1 extern crate slab;
     1 use bytes::{Buf, Bytes};
     2 
     2 use log::*;
       
     3 use slab::Slab;
     3 use std::{
     4 use std::{
     4     collections::HashSet,
     5     collections::HashSet,
     5     io,
     6     io,
     6     io::{Error, ErrorKind, Read, Write},
     7     io::{Error, ErrorKind, Read, Write},
       
     8     iter::Iterator,
     7     mem::{replace, swap},
     9     mem::{replace, swap},
     8     net::{IpAddr, Ipv4Addr, SocketAddr},
    10     net::{IpAddr, Ipv4Addr, SocketAddr},
     9     num::NonZeroU32,
    11     num::NonZeroU32,
    10     time::Duration,
    12     time::Duration,
    11     time::Instant,
    13     time::Instant,
    12 };
    14 };
    13 
    15 use tokio::{
    14 use log::*;
    16     io::AsyncReadExt,
    15 use mio::{
       
    16     event::Source,
       
    17     net::{TcpListener, TcpStream},
    17     net::{TcpListener, TcpStream},
    18     Interest, Poll, Token, Waker,
    18     sync::mpsc::{channel, Receiver, Sender},
    19 };
    19 };
    20 use netbuf;
       
    21 use slab::Slab;
       
    22 
    20 
    23 use crate::{
    21 use crate::{
    24     core::{
    22     core::{
    25         events::{TimedEvents, Timeout},
    23         events::{TimedEvents, Timeout},
    26         types::ClientId,
    24         types::ClientId,
    28     handlers,
    26     handlers,
    29     handlers::{IoResult, IoTask, ServerState},
    27     handlers::{IoResult, IoTask, ServerState},
    30     protocol::ProtocolDecoder,
    28     protocol::ProtocolDecoder,
    31     utils,
    29     utils,
    32 };
    30 };
    33 use hedgewars_network_protocol::{messages::HwServerMessage::Redirect, messages::*};
    31 use hedgewars_network_protocol::{
    34 
    32     messages::HwServerMessage::Redirect, messages::*, parser::server_message,
    35 #[cfg(feature = "official-server")]
    33 };
    36 use super::io::{IoThread, RequestId};
    34 use tokio::io::AsyncWriteExt;
    37 
    35 
    38 #[cfg(feature = "tls-connections")]
    36 enum ClientUpdateData {
    39 use openssl::{
    37     Message(HwProtocolMessage),
    40     error::ErrorStack,
    38     Error(String),
    41     ssl::{
    39 }
    42         HandshakeError, MidHandshakeSslStream, Ssl, SslContext, SslContextBuilder, SslFiletype,
    40 
    43         SslMethod, SslOptions, SslStream, SslStreamBuilder, SslVerifyMode,
    41 struct ClientUpdate {
    44     },
    42     client_id: ClientId,
    45 };
    43     data: ClientUpdateData,
    46 
    44 }
    47 const MAX_BYTES_PER_READ: usize = 2048;
    45 
    48 const SEND_PING_TIMEOUT: Duration = Duration::from_secs(5);
    46 struct ClientUpdateSender {
    49 const DROP_CLIENT_TIMEOUT: Duration = Duration::from_secs(5);
    47     client_id: ClientId,
    50 const MAX_TIMEOUT: usize = DROP_CLIENT_TIMEOUT.as_secs() as usize;
    48     sender: Sender<ClientUpdate>,
    51 const PING_PROBES_COUNT: u8 = 2;
    49 }
    52 
    50 
    53 #[derive(Hash, Eq, PartialEq, Copy, Clone)]
    51 impl ClientUpdateSender {
    54 pub enum NetworkClientState {
    52     async fn send(&mut self, data: ClientUpdateData) -> bool {
    55     Idle,
    53         self.sender
    56     NeedsWrite,
    54             .send(ClientUpdate {
    57     NeedsRead,
    55                 client_id: self.client_id,
    58     Closed,
    56                 data,
    59     #[cfg(feature = "tls-connections")]
    57             })
    60     Connected,
    58             .await
    61 }
    59             .is_ok()
    62 
    60     }
    63 type NetworkResult<T> = io::Result<(T, NetworkClientState)>;
    61 }
    64 
    62 
    65 pub enum ClientSocket {
    63 struct NetworkClient {
    66     Plain(TcpStream),
       
    67     #[cfg(feature = "tls-connections")]
       
    68     SslHandshake(Option<MidHandshakeSslStream<TcpStream>>),
       
    69     #[cfg(feature = "tls-connections")]
       
    70     SslStream(SslStream<TcpStream>),
       
    71 }
       
    72 
       
    73 impl ClientSocket {
       
    74     fn inner_mut(&mut self) -> &mut TcpStream {
       
    75         match self {
       
    76             ClientSocket::Plain(stream) => stream,
       
    77             #[cfg(feature = "tls-connections")]
       
    78             ClientSocket::SslHandshake(Some(builder)) => builder.get_mut(),
       
    79             #[cfg(feature = "tls-connections")]
       
    80             ClientSocket::SslHandshake(None) => unreachable!(),
       
    81             #[cfg(feature = "tls-connections")]
       
    82             ClientSocket::SslStream(ssl_stream) => ssl_stream.get_mut(),
       
    83         }
       
    84     }
       
    85 }
       
    86 
       
    87 pub struct NetworkClient {
       
    88     id: ClientId,
    64     id: ClientId,
    89     socket: ClientSocket,
    65     socket: TcpStream,
       
    66     receiver: Receiver<Bytes>,
    90     peer_addr: SocketAddr,
    67     peer_addr: SocketAddr,
    91     decoder: ProtocolDecoder,
    68     decoder: ProtocolDecoder,
    92     buf_out: netbuf::Buf,
       
    93     pending_close: bool,
       
    94     timeout: Timeout,
       
    95     last_rx_time: Instant,
       
    96 }
    69 }
    97 
    70 
    98 impl NetworkClient {
    71 impl NetworkClient {
    99     pub fn new(
    72     fn new(
   100         id: ClientId,
    73         id: ClientId,
   101         socket: ClientSocket,
    74         socket: TcpStream,
   102         peer_addr: SocketAddr,
    75         peer_addr: SocketAddr,
   103         timeout: Timeout,
    76         receiver: Receiver<Bytes>,
   104     ) -> NetworkClient {
    77     ) -> Self {
   105         NetworkClient {
    78         Self {
   106             id,
    79             id,
   107             socket,
    80             socket,
   108             peer_addr,
    81             peer_addr,
       
    82             receiver,
   109             decoder: ProtocolDecoder::new(),
    83             decoder: ProtocolDecoder::new(),
   110             buf_out: netbuf::Buf::new(),
    84         }
   111             pending_close: false,
    85     }
   112             timeout,
    86 
   113             last_rx_time: Instant::now(),
    87     async fn read(&mut self) -> Option<HwProtocolMessage> {
   114         }
    88         self.decoder.read_from(&mut self.socket).await
   115     }
    89     }
   116 
    90 
   117     #[cfg(feature = "tls-connections")]
    91     async fn write(&mut self, mut data: Bytes) -> bool {
   118     fn handshake_impl(
    92         !data.has_remaining() || matches!(self.socket.write_buf(&mut data).await, Ok(n) if n > 0)
   119         &mut self,
    93     }
   120         handshake: MidHandshakeSslStream<TcpStream>,
    94 
   121     ) -> io::Result<NetworkClientState> {
    95     async fn run(mut self, sender: Sender<ClientUpdate>) {
   122         match handshake.handshake() {
    96         use ClientUpdateData::*;
   123             Ok(stream) => {
    97         let mut sender = ClientUpdateSender {
   124                 self.socket = ClientSocket::SslStream(stream);
    98             client_id: self.id,
   125                 debug!(
    99             sender,
   126                     "TLS handshake with {} ({}) completed",
       
   127                     self.id, self.peer_addr
       
   128                 );
       
   129                 Ok(NetworkClientState::Connected)
       
   130             }
       
   131             Err(HandshakeError::WouldBlock(new_handshake)) => {
       
   132                 self.socket = ClientSocket::SslHandshake(Some(new_handshake));
       
   133                 Ok(NetworkClientState::Idle)
       
   134             }
       
   135             Err(HandshakeError::Failure(new_handshake)) => {
       
   136                 self.socket = ClientSocket::SslHandshake(Some(new_handshake));
       
   137                 debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr);
       
   138                 Err(Error::new(ErrorKind::Other, "Connection failure"))
       
   139             }
       
   140             Err(HandshakeError::SetupFailure(_)) => unreachable!(),
       
   141         }
       
   142     }
       
   143 
       
   144     fn read_impl<R: Read>(
       
   145         decoder: &mut ProtocolDecoder,
       
   146         source: &mut R,
       
   147         id: ClientId,
       
   148         addr: &SocketAddr,
       
   149     ) -> NetworkResult<Vec<HwProtocolMessage>> {
       
   150         let mut bytes_read = 0;
       
   151         let result = loop {
       
   152             match decoder.read_from(source) {
       
   153                 Ok(bytes) => {
       
   154                     debug!("Client {}: read {} bytes", id, bytes);
       
   155                     bytes_read += bytes;
       
   156                     if bytes == 0 {
       
   157                         let result = if bytes_read == 0 {
       
   158                             info!("EOF for client {} ({})", id, addr);
       
   159                             (Vec::new(), NetworkClientState::Closed)
       
   160                         } else {
       
   161                             (decoder.extract_messages(), NetworkClientState::NeedsRead)
       
   162                         };
       
   163                         break Ok(result);
       
   164                     } else if bytes_read >= MAX_BYTES_PER_READ {
       
   165                         break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead));
       
   166                     }
       
   167                 }
       
   168                 Err(ref error) if error.kind() == ErrorKind::WouldBlock => {
       
   169                     let messages = if bytes_read == 0 {
       
   170                         Vec::new()
       
   171                     } else {
       
   172                         decoder.extract_messages()
       
   173                     };
       
   174                     break Ok((messages, NetworkClientState::Idle));
       
   175                 }
       
   176                 Err(error) => break Err(error),
       
   177             }
       
   178         };
   100         };
   179         result
   101 
   180     }
   102         loop {
   181 
   103             tokio::select! {
   182     pub fn read(&mut self) -> NetworkResult<Vec<HwProtocolMessage>> {
   104                 server_message = self.receiver.recv() => {
   183         let result = match self.socket {
   105                     match server_message {
   184             ClientSocket::Plain(ref mut stream) => {
   106                         Some(message) => if !self.write(message).await {
   185                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
   107                             sender.send(Error("Connection reset by peer".to_string())).await;
   186             }
   108                             break;
   187             #[cfg(feature = "tls-connections")]
   109                         }
   188             ClientSocket::SslHandshake(ref mut handshake_opt) => {
   110                         None => {
   189                 let handshake = std::mem::replace(handshake_opt, None).unwrap();
   111                             break;
   190                 Ok((Vec::new(), self.handshake_impl(handshake)?))
   112                         }
   191             }
   113                     }
   192             #[cfg(feature = "tls-connections")]
   114                 }
   193             ClientSocket::SslStream(ref mut stream) => {
   115                 client_message = self.decoder.read_from(&mut self.socket) => {
   194                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
   116                      match client_message {
   195             }
   117                         Some(message) => {
   196         };
   118                             if !sender.send(Message(message)).await {
   197 
   119                                 break;
   198         if let Ok(_) = result {
   120                             }
   199             self.last_rx_time = Instant::now();
   121                         }
   200         }
   122                         None => {
   201 
   123                             sender.send(Error("Connection reset by peer".to_string())).await;
   202         result
   124                             break;
   203     }
   125                         }
   204 
   126                     }
   205     fn write_impl<W: Write>(
   127                 }
   206         buf_out: &mut netbuf::Buf,
   128             }
   207         destination: &mut W,
   129         }
   208         close_on_empty: bool,
   130     }
   209     ) -> NetworkResult<()> {
   131 }
   210         let result = loop {
       
   211             match buf_out.write_to(destination) {
       
   212                 Ok(bytes) if buf_out.is_empty() || bytes == 0 => {
       
   213                     let status = if buf_out.is_empty() && close_on_empty {
       
   214                         NetworkClientState::Closed
       
   215                     } else {
       
   216                         NetworkClientState::Idle
       
   217                     };
       
   218                     break Ok(((), status));
       
   219                 }
       
   220                 Ok(_) => (),
       
   221                 Err(ref error)
       
   222                     if error.kind() == ErrorKind::Interrupted
       
   223                         || error.kind() == ErrorKind::WouldBlock =>
       
   224                 {
       
   225                     break Ok(((), NetworkClientState::NeedsWrite));
       
   226                 }
       
   227                 Err(error) => break Err(error),
       
   228             }
       
   229         };
       
   230         result
       
   231     }
       
   232 
       
   233     pub fn write(&mut self) -> NetworkResult<()> {
       
   234         let result = match self.socket {
       
   235             ClientSocket::Plain(ref mut stream) => {
       
   236                 NetworkClient::write_impl(&mut self.buf_out, stream, self.pending_close)
       
   237             }
       
   238             #[cfg(feature = "tls-connections")]
       
   239             ClientSocket::SslHandshake(ref mut handshake_opt) => {
       
   240                 let handshake = std::mem::replace(handshake_opt, None).unwrap();
       
   241                 Ok(((), self.handshake_impl(handshake)?))
       
   242             }
       
   243             #[cfg(feature = "tls-connections")]
       
   244             ClientSocket::SslStream(ref mut stream) => {
       
   245                 NetworkClient::write_impl(&mut self.buf_out, stream, self.pending_close)
       
   246             }
       
   247         };
       
   248 
       
   249         self.socket.inner_mut().flush()?;
       
   250         result
       
   251     }
       
   252 
       
   253     pub fn send_raw_msg(&mut self, msg: &[u8]) {
       
   254         self.buf_out.write_all(msg).unwrap();
       
   255     }
       
   256 
       
   257     pub fn send_string(&mut self, msg: &str) {
       
   258         self.send_raw_msg(&msg.as_bytes());
       
   259     }
       
   260 
       
   261     pub fn replace_timeout(&mut self, timeout: Timeout) -> Timeout {
       
   262         replace(&mut self.timeout, timeout)
       
   263     }
       
   264 
       
   265     pub fn has_pending_sends(&self) -> bool {
       
   266         !self.buf_out.is_empty()
       
   267     }
       
   268 }
       
   269 
       
   270 #[cfg(feature = "tls-connections")]
       
   271 struct ServerSsl {
       
   272     listener: TcpListener,
       
   273     context: SslContext,
       
   274 }
       
   275 
       
   276 #[cfg(feature = "official-server")]
       
   277 pub struct IoLayer {
       
   278     next_request_id: RequestId,
       
   279     request_queue: Vec<(RequestId, ClientId)>,
       
   280     io_thread: IoThread,
       
   281 }
       
   282 
       
   283 #[cfg(feature = "official-server")]
       
   284 impl IoLayer {
       
   285     fn new(waker: Waker) -> Self {
       
   286         Self {
       
   287             next_request_id: 0,
       
   288             request_queue: vec![],
       
   289             io_thread: IoThread::new(waker),
       
   290         }
       
   291     }
       
   292 
       
   293     fn send(&mut self, client_id: ClientId, task: IoTask) {
       
   294         let request_id = self.next_request_id;
       
   295         self.next_request_id += 1;
       
   296         self.request_queue.push((request_id, client_id));
       
   297         self.io_thread.send(request_id, task);
       
   298     }
       
   299 
       
   300     fn try_recv(&mut self) -> Option<(ClientId, IoResult)> {
       
   301         let (request_id, result) = self.io_thread.try_recv()?;
       
   302         if let Some(index) = self
       
   303             .request_queue
       
   304             .iter()
       
   305             .position(|(id, _)| *id == request_id)
       
   306         {
       
   307             let (_, client_id) = self.request_queue.swap_remove(index);
       
   308             Some((client_id, result))
       
   309         } else {
       
   310             None
       
   311         }
       
   312     }
       
   313 
       
   314     fn cancel(&mut self, client_id: ClientId) {
       
   315         let mut index = 0;
       
   316         while index < self.request_queue.len() {
       
   317             if self.request_queue[index].1 == client_id {
       
   318                 self.request_queue.swap_remove(index);
       
   319             } else {
       
   320                 index += 1;
       
   321             }
       
   322         }
       
   323     }
       
   324 }
       
   325 
       
   326 enum TimeoutEvent {
       
   327     SendPing { probes_count: u8 },
       
   328     DropClient,
       
   329 }
       
   330 
       
   331 struct TimerData(TimeoutEvent, ClientId);
       
   332 type NetworkTimeoutEvents = TimedEvents<TimerData, MAX_TIMEOUT>;
       
   333 
   132 
   334 pub struct NetworkLayer {
   133 pub struct NetworkLayer {
   335     listener: TcpListener,
   134     listener: TcpListener,
   336     server_state: ServerState,
   135     server_state: ServerState,
   337     clients: Slab<NetworkClient>,
   136     clients: Slab<Sender<Bytes>>,
   338     pending: HashSet<(ClientId, NetworkClientState)>,
       
   339     pending_cache: Vec<(ClientId, NetworkClientState)>,
       
   340     #[cfg(feature = "tls-connections")]
       
   341     ssl: ServerSsl,
       
   342     #[cfg(feature = "official-server")]
       
   343     io: IoLayer,
       
   344     timeout_events: NetworkTimeoutEvents,
       
   345 }
       
   346 
       
   347 fn register_read<S: Source>(poll: &Poll, source: &mut S, token: mio::Token) -> io::Result<()> {
       
   348     poll.registry().register(source, token, Interest::READABLE)
       
   349 }
       
   350 
       
   351 fn create_ping_timeout(
       
   352     timeout_events: &mut NetworkTimeoutEvents,
       
   353     probes_count: u8,
       
   354     client_id: ClientId,
       
   355 ) -> Timeout {
       
   356     timeout_events.set_timeout(
       
   357         NonZeroU32::new(SEND_PING_TIMEOUT.as_secs() as u32).unwrap(),
       
   358         TimerData(TimeoutEvent::SendPing { probes_count }, client_id),
       
   359     )
       
   360 }
       
   361 
       
   362 fn create_drop_timeout(timeout_events: &mut NetworkTimeoutEvents, client_id: ClientId) -> Timeout {
       
   363     timeout_events.set_timeout(
       
   364         NonZeroU32::new(DROP_CLIENT_TIMEOUT.as_secs() as u32).unwrap(),
       
   365         TimerData(TimeoutEvent::DropClient, client_id),
       
   366     )
       
   367 }
   137 }
   368 
   138 
   369 impl NetworkLayer {
   139 impl NetworkLayer {
   370     pub fn register(&mut self, poll: &Poll) -> io::Result<()> {
   140     pub async fn run(&mut self) {
   371         register_read(poll, &mut self.listener, utils::SERVER_TOKEN)?;
   141         let (update_tx, mut update_rx) = channel(128);
   372         #[cfg(feature = "tls-connections")]
   142 
   373         register_read(poll, &mut self.ssl.listener, utils::SECURE_SERVER_TOKEN)?;
   143         loop {
   374 
   144             tokio::select! {
   375         Ok(())
   145                 Ok((stream, addr)) = self.listener.accept() => {
   376     }
   146                     if let Some(client) = self.create_client(stream, addr).await {
   377 
   147                         tokio::spawn(client.run(update_tx.clone()));
   378     fn deregister_client(&mut self, poll: &Poll, id: ClientId, is_error: bool) {
   148                     }
   379         if let Some(ref mut client) = self.clients.get_mut(id) {
   149                 }
   380             poll.registry()
   150                 client_message = update_rx.recv(), if !self.clients.is_empty() => {
   381                 .deregister(client.socket.inner_mut())
   151                     use ClientUpdateData::*;
   382                 .expect("could not deregister socket");
   152                     match client_message {
   383             if client.has_pending_sends() && !is_error {
   153                         Some(ClientUpdate{ client_id, data: Message(message) } ) => {
   384                 info!(
   154                             self.handle_message(client_id, message).await;
   385                     "client {} ({}) pending removal",
   155                         }
   386                     client.id, client.peer_addr
   156                         Some(ClientUpdate{ client_id, .. } ) => {
   387                 );
   157                             let mut response = handlers::Response::new(client_id);
   388                 client.pending_close = true;
   158                             handlers::handle_client_loss(&mut self.server_state, client_id, &mut response);
   389                 poll.registry()
   159                             self.handle_response(response).await;
   390                     .register(client.socket.inner_mut(), Token(id), Interest::WRITABLE)
   160                         }
   391                     .unwrap_or_else(|_| {
   161                         None => unreachable!()
   392                         self.clients.remove(id);
   162                     }
   393                     });
   163                 }
   394             } else {
   164             }
   395                 info!("client {} ({}) removed", client.id, client.peer_addr);
   165         }
   396                 self.clients.remove(id);
   166     }
   397             }
   167 
   398             #[cfg(feature = "official-server")]
   168     async fn create_client(
   399             self.io.cancel(id);
       
   400         }
       
   401     }
       
   402 
       
   403     fn register_client(
       
   404         &mut self,
   169         &mut self,
   405         poll: &Poll,
   170         stream: TcpStream,
   406         mut client_socket: ClientSocket,
       
   407         addr: SocketAddr,
   171         addr: SocketAddr,
   408     ) -> io::Result<ClientId> {
   172     ) -> Option<NetworkClient> {
   409         let entry = self.clients.vacant_entry();
   173         let entry = self.clients.vacant_entry();
   410         let client_id = entry.key();
   174         let client_id = entry.key();
   411 
   175         let (tx, rx) = channel(16);
   412         poll.registry().register(
   176         entry.insert(tx);
   413             client_socket.inner_mut(),
   177 
   414             Token(client_id),
   178         let client = NetworkClient::new(client_id, stream, addr, rx);
   415             Interest::READABLE | Interest::WRITABLE,
   179 
   416         )?;
       
   417 
       
   418         let client = NetworkClient::new(
       
   419             client_id,
       
   420             client_socket,
       
   421             addr,
       
   422             create_ping_timeout(&mut self.timeout_events, PING_PROBES_COUNT - 1, client_id),
       
   423         );
       
   424         info!("client {} ({}) added", client.id, client.peer_addr);
   180         info!("client {} ({}) added", client.id, client.peer_addr);
   425         entry.insert(client);
   181 
   426 
       
   427         Ok(client_id)
       
   428     }
       
   429 
       
   430     fn handle_response(&mut self, mut response: handlers::Response, poll: &Poll) {
       
   431         if response.is_empty() {
       
   432             return;
       
   433         }
       
   434 
       
   435         debug!("{} pending server messages", response.len());
       
   436         let output = response.extract_messages(&mut self.server_state.server);
       
   437         for (clients, message) in output {
       
   438             debug!("Message {:?} to {:?}", message, clients);
       
   439             let msg_string = message.to_raw_protocol();
       
   440             for client_id in clients {
       
   441                 if let Some(client) = self.clients.get_mut(client_id) {
       
   442                     client.send_string(&msg_string);
       
   443                     self.pending
       
   444                         .insert((client_id, NetworkClientState::NeedsWrite));
       
   445                 }
       
   446             }
       
   447         }
       
   448 
       
   449         for client_id in response.extract_removed_clients() {
       
   450             self.deregister_client(poll, client_id, false);
       
   451         }
       
   452 
       
   453         #[cfg(feature = "official-server")]
       
   454         {
       
   455             let client_id = response.client_id();
       
   456             for task in response.extract_io_tasks() {
       
   457                 self.io.send(client_id, task);
       
   458             }
       
   459         }
       
   460     }
       
   461 
       
   462     pub fn handle_timeout(&mut self, poll: &mut Poll) -> io::Result<()> {
       
   463         for TimerData(event, client_id) in self.timeout_events.poll(Instant::now()) {
       
   464             if let Some(client) = self.clients.get_mut(client_id) {
       
   465                 if client.last_rx_time.elapsed() > SEND_PING_TIMEOUT {
       
   466                     match event {
       
   467                         TimeoutEvent::SendPing { probes_count } => {
       
   468                             client.send_string(&HwServerMessage::Ping.to_raw_protocol());
       
   469                             client.write()?;
       
   470                             let timeout = if probes_count != 0 {
       
   471                                 create_ping_timeout(
       
   472                                     &mut self.timeout_events,
       
   473                                     probes_count - 1,
       
   474                                     client_id,
       
   475                                 )
       
   476                             } else {
       
   477                                 create_drop_timeout(&mut self.timeout_events, client_id)
       
   478                             };
       
   479                             client.replace_timeout(timeout);
       
   480                         }
       
   481                         TimeoutEvent::DropClient => {
       
   482                             client.send_string(
       
   483                                 &HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol(),
       
   484                             );
       
   485                             let _res = client.write();
       
   486 
       
   487                             self.operation_failed(
       
   488                                 poll,
       
   489                                 client_id,
       
   490                                 &ErrorKind::TimedOut.into(),
       
   491                                 "No ping response",
       
   492                             )?;
       
   493                         }
       
   494                     }
       
   495                 } else {
       
   496                     client.replace_timeout(create_ping_timeout(
       
   497                         &mut self.timeout_events,
       
   498                         PING_PROBES_COUNT - 1,
       
   499                         client_id,
       
   500                     ));
       
   501                 }
       
   502             }
       
   503         }
       
   504         Ok(())
       
   505     }
       
   506 
       
   507     #[cfg(feature = "official-server")]
       
   508     pub fn handle_io_result(&mut self, poll: &Poll) -> io::Result<()> {
       
   509         while let Some((client_id, result)) = self.io.try_recv() {
       
   510             debug!("Handling io result {:?} for client {}", result, client_id);
       
   511             let mut response = handlers::Response::new(client_id);
       
   512             handlers::handle_io_result(&mut self.server_state, client_id, &mut response, result);
       
   513             self.handle_response(response, poll);
       
   514         }
       
   515         Ok(())
       
   516     }
       
   517 
       
   518     fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
       
   519         Ok(ClientSocket::Plain(socket))
       
   520     }
       
   521 
       
   522     #[cfg(feature = "tls-connections")]
       
   523     fn create_client_secure_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
       
   524         let ssl = Ssl::new(&self.ssl.context).unwrap();
       
   525         let mut builder = SslStreamBuilder::new(ssl, socket);
       
   526         builder.set_accept_state();
       
   527         match builder.handshake() {
       
   528             Ok(stream) => Ok(ClientSocket::SslStream(stream)),
       
   529             Err(HandshakeError::WouldBlock(stream)) => Ok(ClientSocket::SslHandshake(Some(stream))),
       
   530             Err(e) => {
       
   531                 debug!("OpenSSL handshake failed: {}", e);
       
   532                 Err(Error::new(ErrorKind::Other, "Connection failure"))
       
   533             }
       
   534         }
       
   535     }
       
   536 
       
   537     fn init_client(&mut self, poll: &Poll, client_id: ClientId) {
       
   538         let mut response = handlers::Response::new(client_id);
   182         let mut response = handlers::Response::new(client_id);
   539 
   183 
   540         if let ClientSocket::Plain(_) = self.clients[client_id].socket {
   184         let added = if let IpAddr::V4(addr) = client.peer_addr.ip() {
   541             #[cfg(feature = "tls-connections")]
       
   542             response.add(Redirect(self.ssl.listener.local_addr().unwrap().port()).send_self())
       
   543         }
       
   544 
       
   545         if let IpAddr::V4(addr) = self.clients[client_id].peer_addr.ip() {
       
   546             handlers::handle_client_accept(
   185             handlers::handle_client_accept(
   547                 &mut self.server_state,
   186                 &mut self.server_state,
   548                 client_id,
   187                 client_id,
   549                 &mut response,
   188                 &mut response,
   550                 addr.octets(),
   189                 addr.octets(),
   551                 addr.is_loopback(),
   190                 addr.is_loopback(),
   552             );
   191             )
   553             self.handle_response(response, poll);
       
   554         } else {
   192         } else {
   555             todo!("implement something")
   193             todo!("implement something")
   556         }
   194         };
   557     }
   195 
   558 
   196         self.handle_response(response).await;
   559     pub fn accept_client(&mut self, poll: &Poll, server_token: mio::Token) -> io::Result<()> {
   197 
   560         match server_token {
   198         if added {
   561             utils::SERVER_TOKEN => {
   199             Some(client)
   562                 let (client_socket, addr) = self.listener.accept()?;
       
   563                 info!("Connected(plaintext): {}", addr);
       
   564                 let client_id =
       
   565                     self.register_client(poll, self.create_client_socket(client_socket)?, addr)?;
       
   566                 self.init_client(poll, client_id);
       
   567             }
       
   568             #[cfg(feature = "tls-connections")]
       
   569             utils::SECURE_SERVER_TOKEN => {
       
   570                 let (client_socket, addr) = self.ssl.listener.accept()?;
       
   571                 info!("Connected(TLS): {}", addr);
       
   572                 self.register_client(poll, self.create_client_secure_socket(client_socket)?, addr)?;
       
   573             }
       
   574             _ => unreachable!(),
       
   575         }
       
   576 
       
   577         Ok(())
       
   578     }
       
   579 
       
   580     fn operation_failed(
       
   581         &mut self,
       
   582         poll: &Poll,
       
   583         client_id: ClientId,
       
   584         error: &Error,
       
   585         msg: &str,
       
   586     ) -> io::Result<()> {
       
   587         let addr = if let Some(ref mut client) = self.clients.get_mut(client_id) {
       
   588             client.peer_addr
       
   589         } else {
   200         } else {
   590             SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0)
   201             None
   591         };
   202         }
   592         debug!("{}({}): {}", msg, addr, error);
   203     }
   593         self.client_error(poll, client_id)
   204 
   594     }
   205     async fn handle_message(&mut self, client_id: ClientId, message: HwProtocolMessage) {
   595 
   206         debug!("Handling message {:?} for client {}", message, client_id);
   596     pub fn client_readable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
       
   597         let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) {
       
   598             client.read()
       
   599         } else {
       
   600             warn!("invalid readable client: {}", client_id);
       
   601             Ok((Vec::new(), NetworkClientState::Idle))
       
   602         };
       
   603 
       
   604         let mut response = handlers::Response::new(client_id);
   207         let mut response = handlers::Response::new(client_id);
   605 
   208         handlers::handle(&mut self.server_state, client_id, &mut response, message);
   606         match messages {
   209         self.handle_response(response).await;
   607             Ok((messages, state)) => {
   210     }
   608                 for message in messages {
   211 
   609                     debug!("Handling message {:?} for client {}", message, client_id);
   212     async fn handle_response(&mut self, mut response: handlers::Response) {
   610                     handlers::handle(&mut self.server_state, client_id, &mut response, message);
   213         if response.is_empty() {
   611                 }
   214             return;
   612                 match state {
   215         }
   613                     NetworkClientState::NeedsRead => {
   216 
   614                         self.pending.insert((client_id, state));
   217         debug!("{} pending server messages", response.len());
   615                     }
   218         let output = response.extract_messages(&mut self.server_state.server);
   616                     NetworkClientState::Closed => self.client_error(&poll, client_id)?,
   219         for (clients, message) in output {
   617                     #[cfg(feature = "tls-connections")]
   220             debug!("Message {:?} to {:?}", message, clients);
   618                     NetworkClientState::Connected => self.init_client(poll, client_id),
   221             Self::send_message(&mut self.clients, message, clients.iter().cloned()).await;
   619                     _ => {}
   222         }
   620                 };
   223 
   621             }
   224         for client_id in response.extract_removed_clients() {
   622             Err(e) => self.operation_failed(
   225             if self.clients.contains(client_id) {
   623                 poll,
   226                 self.clients.remove(client_id);
   624                 client_id,
   227             }
   625                 &e,
   228             info!("Client {} removed", client_id);
   626                 "Error while reading from client socket",
   229         }
   627             )?,
   230     }
   628         }
   231 
   629 
   232     async fn send_message<I>(
   630         self.handle_response(response, poll);
   233         clients: &mut Slab<Sender<Bytes>>,
   631 
   234         message: HwServerMessage,
   632         Ok(())
   235         to_clients: I,
   633     }
   236     ) where
   634 
   237         I: Iterator<Item = ClientId>,
   635     pub fn client_writable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
   238     {
   636         let result = if let Some(ref mut client) = self.clients.get_mut(client_id) {
   239         let msg_string = message.to_raw_protocol();
   637             client.write()
   240         let bytes = Bytes::copy_from_slice(msg_string.as_bytes());
   638         } else {
   241         for client_id in to_clients {
   639             warn!("invalid writable client: {}", client_id);
   242             if let Some(client) = clients.get_mut(client_id) {
   640             Ok(((), NetworkClientState::Idle))
   243                 if !client.send(bytes.clone()).await.is_ok() {
   641         };
   244                     clients.remove(client_id);
   642 
   245                 }
   643         match result {
   246             }
   644             Ok(((), state)) if state == NetworkClientState::NeedsWrite => {
   247         }
   645                 self.pending.insert((client_id, state));
       
   646             }
       
   647             Ok(((), state)) if state == NetworkClientState::Closed => {
       
   648                 self.deregister_client(poll, client_id, false);
       
   649             }
       
   650             Ok(_) => (),
       
   651             Err(e) => {
       
   652                 self.operation_failed(poll, client_id, &e, "Error while writing to client socket")?
       
   653             }
       
   654         }
       
   655 
       
   656         Ok(())
       
   657     }
       
   658 
       
   659     pub fn client_error(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
       
   660         let pending_close = self.clients[client_id].pending_close;
       
   661         self.deregister_client(poll, client_id, true);
       
   662 
       
   663         if !pending_close {
       
   664             let mut response = handlers::Response::new(client_id);
       
   665             handlers::handle_client_loss(&mut self.server_state, client_id, &mut response);
       
   666             self.handle_response(response, poll);
       
   667         }
       
   668 
       
   669         Ok(())
       
   670     }
       
   671 
       
   672     pub fn has_pending_operations(&self) -> bool {
       
   673         !self.pending.is_empty() || !self.timeout_events.is_empty()
       
   674     }
       
   675 
       
   676     pub fn on_idle(&mut self, poll: &Poll) -> io::Result<()> {
       
   677         if self.has_pending_operations() {
       
   678             let mut cache = replace(&mut self.pending_cache, Vec::new());
       
   679             cache.extend(self.pending.drain());
       
   680             for (id, state) in cache.drain(..) {
       
   681                 match state {
       
   682                     NetworkClientState::NeedsRead => self.client_readable(poll, id)?,
       
   683                     NetworkClientState::NeedsWrite => self.client_writable(poll, id)?,
       
   684                     _ => {}
       
   685                 }
       
   686             }
       
   687             swap(&mut cache, &mut self.pending_cache);
       
   688         }
       
   689         Ok(())
       
   690     }
   248     }
   691 }
   249 }
   692 
   250 
   693 pub struct NetworkLayerBuilder {
   251 pub struct NetworkLayerBuilder {
   694     listener: Option<TcpListener>,
   252     listener: Option<TcpListener>,
   695     secure_listener: Option<TcpListener>,
       
   696     clients_capacity: usize,
   253     clients_capacity: usize,
   697     rooms_capacity: usize,
   254     rooms_capacity: usize,
   698 }
   255 }
   699 
   256 
   700 impl Default for NetworkLayerBuilder {
   257 impl Default for NetworkLayerBuilder {
   701     fn default() -> Self {
   258     fn default() -> Self {
   702         Self {
   259         Self {
   703             clients_capacity: 1024,
   260             clients_capacity: 1024,
   704             rooms_capacity: 512,
   261             rooms_capacity: 512,
   705             listener: None,
   262             listener: None,
   706             secure_listener: None,
       
   707         }
   263         }
   708     }
   264     }
   709 }
   265 }
   710 
   266 
   711 impl NetworkLayerBuilder {
   267 impl NetworkLayerBuilder {
   714             listener: Some(listener),
   270             listener: Some(listener),
   715             ..self
   271             ..self
   716         }
   272         }
   717     }
   273     }
   718 
   274 
   719     pub fn with_secure_listener(self, listener: TcpListener) -> Self {
   275     pub fn build(self) -> NetworkLayer {
   720         Self {
       
   721             secure_listener: Some(listener),
       
   722             ..self
       
   723         }
       
   724     }
       
   725 
       
   726     #[cfg(feature = "tls-connections")]
       
   727     fn create_ssl_context(listener: TcpListener) -> ServerSsl {
       
   728         let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap();
       
   729         builder.set_verify(SslVerifyMode::NONE);
       
   730         builder.set_read_ahead(true);
       
   731         builder
       
   732             .set_certificate_file("ssl/cert.pem", SslFiletype::PEM)
       
   733             .expect("Cannot find certificate file");
       
   734         builder
       
   735             .set_private_key_file("ssl/key.pem", SslFiletype::PEM)
       
   736             .expect("Cannot find private key file");
       
   737         builder.set_options(SslOptions::NO_COMPRESSION);
       
   738         builder.set_options(SslOptions::NO_TLSV1);
       
   739         builder.set_options(SslOptions::NO_TLSV1_1);
       
   740         builder.set_cipher_list("ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384").unwrap();
       
   741         ServerSsl {
       
   742             listener,
       
   743             context: builder.build(),
       
   744         }
       
   745     }
       
   746 
       
   747     pub fn build(self, poll: &Poll) -> NetworkLayer {
       
   748         let server_state = ServerState::new(self.clients_capacity, self.rooms_capacity);
   276         let server_state = ServerState::new(self.clients_capacity, self.rooms_capacity);
   749 
   277 
   750         let clients = Slab::with_capacity(self.clients_capacity);
   278         let clients = Slab::with_capacity(self.clients_capacity);
   751         let pending = HashSet::with_capacity(2 * self.clients_capacity);
       
   752         let pending_cache = Vec::with_capacity(2 * self.clients_capacity);
       
   753         let timeout_events = NetworkTimeoutEvents::new();
       
   754 
       
   755         #[cfg(feature = "official-server")]
       
   756         let waker = Waker::new(poll.registry(), utils::IO_TOKEN)
       
   757             .expect("Unable to create a waker for the IO thread");
       
   758 
   279 
   759         NetworkLayer {
   280         NetworkLayer {
   760             listener: self.listener.expect("No listener provided"),
   281             listener: self.listener.expect("No listener provided"),
   761             server_state,
   282             server_state,
   762             clients,
   283             clients,
   763             pending,
   284         }
   764             pending_cache,
   285     }
   765             #[cfg(feature = "tls-connections")]
   286 }
   766             ssl: Self::create_ssl_context(
       
   767                 self.secure_listener.expect("No secure listener provided"),
       
   768             ),
       
   769             #[cfg(feature = "official-server")]
       
   770             io: IoLayer::new(waker),
       
   771             timeout_events,
       
   772         }
       
   773     }
       
   774 }