rust/hedgewars-server/src/server/network.rs
changeset 14478 98ef2913ec73
parent 14436 06672690d71b
child 14692 455865ccd36c
--- a/rust/hedgewars-server/src/server/network.rs	Sun Dec 16 00:09:20 2018 +0100
+++ b/rust/hedgewars-server/src/server/network.rs	Sun Dec 16 00:12:29 2018 +0100
@@ -1,37 +1,33 @@
 extern crate slab;
 
 use std::{
-    io, io::{Error, ErrorKind, Read, Write},
-    net::{SocketAddr, IpAddr, Ipv4Addr},
     collections::HashSet,
-    mem::{swap, replace}
+    io,
+    io::{Error, ErrorKind, Read, Write},
+    mem::{replace, swap},
+    net::{IpAddr, Ipv4Addr, SocketAddr},
 };
 
+use log::*;
 use mio::{
-    net::{TcpStream, TcpListener},
-    Poll, PollOpt, Ready, Token
+    net::{TcpListener, TcpStream},
+    Poll, PollOpt, Ready, Token,
 };
 use netbuf;
 use slab::Slab;
-use log::*;
 
+use super::{core::HWServer, coretypes::ClientId, io::FileServerIO};
 use crate::{
+    protocol::{messages::*, ProtocolDecoder},
     utils,
-    protocol::{ProtocolDecoder, messages::*}
-};
-use super::{
-    io::FileServerIO,
-    core::{HWServer},
-    coretypes::ClientId
 };
 #[cfg(feature = "tls-connections")]
 use openssl::{
+    error::ErrorStack,
     ssl::{
-        SslMethod, SslContext, Ssl, SslContextBuilder,
-        SslVerifyMode, SslFiletype, SslOptions,
-        SslStreamBuilder, HandshakeError, MidHandshakeSslStream, SslStream
+        HandshakeError, MidHandshakeSslStream, Ssl, SslContext, SslContextBuilder, SslFiletype,
+        SslMethod, SslOptions, SslStream, SslStreamBuilder, SslVerifyMode,
     },
-    error::ErrorStack
 };
 
 const MAX_BYTES_PER_READ: usize = 2048;
@@ -48,13 +44,13 @@
 
 #[cfg(not(feature = "tls-connections"))]
 pub enum ClientSocket {
-    Plain(TcpStream)
+    Plain(TcpStream),
 }
 
 #[cfg(feature = "tls-connections")]
 pub enum ClientSocket {
     SslHandshake(Option<MidHandshakeSslStream<TcpStream>>),
-    SslStream(SslStream<TcpStream>)
+    SslStream(SslStream<TcpStream>),
 }
 
 impl ClientSocket {
@@ -68,7 +64,7 @@
         match self {
             ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(),
             ClientSocket::SslHandshake(None) => unreachable!(),
-            ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref()
+            ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref(),
         }
     }
 }
@@ -78,24 +74,32 @@
     socket: ClientSocket,
     peer_addr: SocketAddr,
     decoder: ProtocolDecoder,
-    buf_out: netbuf::Buf
+    buf_out: netbuf::Buf,
 }
 
 impl NetworkClient {
     pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient {
         NetworkClient {
-            id, socket, peer_addr,
+            id,
+            socket,
+            peer_addr,
             decoder: ProtocolDecoder::new(),
-            buf_out: netbuf::Buf::new()
+            buf_out: netbuf::Buf::new(),
         }
     }
 
     #[cfg(feature = "tls-connections")]
-    fn handshake_impl(&mut self, handshake: MidHandshakeSslStream<TcpStream>) -> io::Result<NetworkClientState> {
+    fn handshake_impl(
+        &mut self,
+        handshake: MidHandshakeSslStream<TcpStream>,
+    ) -> io::Result<NetworkClientState> {
         match handshake.handshake() {
             Ok(stream) => {
                 self.socket = ClientSocket::SslStream(stream);
-                debug!("TLS handshake with {} ({}) completed", self.id, self.peer_addr);
+                debug!(
+                    "TLS handshake with {} ({}) completed",
+                    self.id, self.peer_addr
+                );
                 Ok(NetworkClientState::Idle)
             }
             Err(HandshakeError::WouldBlock(new_handshake)) => {
@@ -107,12 +111,16 @@
                 debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr);
                 Err(Error::new(ErrorKind::Other, "Connection failure"))
             }
-            Err(HandshakeError::SetupFailure(_)) => unreachable!()
+            Err(HandshakeError::SetupFailure(_)) => unreachable!(),
         }
     }
 
-    fn read_impl<R: Read>(decoder: &mut ProtocolDecoder, source: &mut R,
-                          id: ClientId, addr: &SocketAddr) -> NetworkResult<Vec<HWProtocolMessage>> {
+    fn read_impl<R: Read>(
+        decoder: &mut ProtocolDecoder,
+        source: &mut R,
+        id: ClientId,
+        addr: &SocketAddr,
+    ) -> NetworkResult<Vec<HWProtocolMessage>> {
         let mut bytes_read = 0;
         let result = loop {
             match decoder.read_from(source) {
@@ -127,21 +135,19 @@
                             (decoder.extract_messages(), NetworkClientState::NeedsRead)
                         };
                         break Ok(result);
-                    }
-                    else if bytes_read >= MAX_BYTES_PER_READ {
-                        break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead))
+                    } else if bytes_read >= MAX_BYTES_PER_READ {
+                        break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead));
                     }
                 }
                 Err(ref error) if error.kind() == ErrorKind::WouldBlock => {
-                    let messages =  if bytes_read == 0 {
+                    let messages = if bytes_read == 0 {
                         Vec::new()
                     } else {
                         decoder.extract_messages()
                     };
                     break Ok((messages, NetworkClientState::Idle));
                 }
-                Err(error) =>
-                    break Err(error)
+                Err(error) => break Err(error),
             }
         };
         decoder.sweep();
@@ -151,8 +157,9 @@
     pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> {
         #[cfg(not(feature = "tls-connections"))]
         match self.socket {
-            ClientSocket::Plain(ref mut stream) =>
-                NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr),
+            ClientSocket::Plain(ref mut stream) => {
+                NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
+            }
         }
 
         #[cfg(feature = "tls-connections")]
@@ -160,24 +167,27 @@
             ClientSocket::SslHandshake(ref mut handshake_opt) => {
                 let handshake = std::mem::replace(handshake_opt, None).unwrap();
                 Ok((Vec::new(), self.handshake_impl(handshake)?))
-            },
-            ClientSocket::SslStream(ref mut stream) =>
+            }
+            ClientSocket::SslStream(ref mut stream) => {
                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
+            }
         }
     }
 
     fn write_impl<W: Write>(buf_out: &mut netbuf::Buf, destination: &mut W) -> NetworkResult<()> {
         let result = loop {
             match buf_out.write_to(destination) {
-                Ok(bytes) if buf_out.is_empty() || bytes == 0 =>
-                    break Ok(((), NetworkClientState::Idle)),
+                Ok(bytes) if buf_out.is_empty() || bytes == 0 => {
+                    break Ok(((), NetworkClientState::Idle))
+                }
                 Ok(_) => (),
-                Err(ref error) if error.kind() == ErrorKind::Interrupted
-                    || error.kind() == ErrorKind::WouldBlock => {
+                Err(ref error)
+                    if error.kind() == ErrorKind::Interrupted
+                        || error.kind() == ErrorKind::WouldBlock =>
+                {
                     break Ok(((), NetworkClientState::NeedsWrite));
-                },
-                Err(error) =>
-                    break Err(error)
+                }
+                Err(error) => break Err(error),
             }
         };
         result
@@ -187,18 +197,21 @@
         let result = {
             #[cfg(not(feature = "tls-connections"))]
             match self.socket {
-                ClientSocket::Plain(ref mut stream) =>
+                ClientSocket::Plain(ref mut stream) => {
                     NetworkClient::write_impl(&mut self.buf_out, stream)
+                }
             }
 
-            #[cfg(feature = "tls-connections")] {
+            #[cfg(feature = "tls-connections")]
+            {
                 match self.socket {
                     ClientSocket::SslHandshake(ref mut handshake_opt) => {
                         let handshake = std::mem::replace(handshake_opt, None).unwrap();
                         Ok(((), self.handshake_impl(handshake)?))
                     }
-                    ClientSocket::SslStream(ref mut stream) =>
+                    ClientSocket::SslStream(ref mut stream) => {
                         NetworkClient::write_impl(&mut self.buf_out, stream)
+                    }
                 }
             }
         };
@@ -222,7 +235,7 @@
 
 #[cfg(feature = "tls-connections")]
 struct ServerSsl {
-    context: SslContext
+    context: SslContext,
 }
 
 pub struct NetworkLayer {
@@ -232,7 +245,7 @@
     pending: HashSet<(ClientId, NetworkClientState)>,
     pending_cache: Vec<(ClientId, NetworkClientState)>,
     #[cfg(feature = "tls-connections")]
-    ssl: ServerSsl
+    ssl: ServerSsl,
 }
 
 impl NetworkLayer {
@@ -243,9 +256,13 @@
         let pending_cache = Vec::with_capacity(2 * clients_limit);
 
         NetworkLayer {
-            listener, server, clients, pending, pending_cache,
+            listener,
+            server,
+            clients,
+            pending,
+            pending_cache,
             #[cfg(feature = "tls-connections")]
-            ssl: NetworkLayer::create_ssl_context()
+            ssl: NetworkLayer::create_ssl_context(),
         }
     }
 
@@ -254,16 +271,26 @@
         let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap();
         builder.set_verify(SslVerifyMode::NONE);
         builder.set_read_ahead(true);
-        builder.set_certificate_file("ssl/cert.pem", SslFiletype::PEM).unwrap();
-        builder.set_private_key_file("ssl/key.pem", SslFiletype::PEM).unwrap();
+        builder
+            .set_certificate_file("ssl/cert.pem", SslFiletype::PEM)
+            .unwrap();
+        builder
+            .set_private_key_file("ssl/key.pem", SslFiletype::PEM)
+            .unwrap();
         builder.set_options(SslOptions::NO_COMPRESSION);
         builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap();
-        ServerSsl { context: builder.build() }
+        ServerSsl {
+            context: builder.build(),
+        }
     }
 
     pub fn register_server(&self, poll: &Poll) -> io::Result<()> {
-        poll.register(&self.listener, utils::SERVER, Ready::readable(),
-                      PollOpt::edge())
+        poll.register(
+            &self.listener,
+            utils::SERVER,
+            Ready::readable(),
+            PollOpt::edge(),
+        )
     }
 
     fn deregister_client(&mut self, poll: &Poll, id: ClientId) {
@@ -279,11 +306,20 @@
         }
     }
 
-    fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: ClientSocket, addr: SocketAddr) {
-        poll.register(client_socket.inner(), Token(id),
-                      Ready::readable() | Ready::writable(),
-                      PollOpt::edge())
-            .expect("could not register socket with event loop");
+    fn register_client(
+        &mut self,
+        poll: &Poll,
+        id: ClientId,
+        client_socket: ClientSocket,
+        addr: SocketAddr,
+    ) {
+        poll.register(
+            client_socket.inner(),
+            Token(id),
+            Ready::readable() | Ready::writable(),
+            PollOpt::edge(),
+        )
+        .expect("could not register socket with event loop");
 
         let entry = self.clients.vacant_entry();
         let client = NetworkClient::new(id, client_socket, addr);
@@ -299,26 +335,29 @@
             for client_id in clients {
                 if let Some(client) = self.clients.get_mut(client_id) {
                     client.send_string(&msg_string);
-                    self.pending.insert((client_id, NetworkClientState::NeedsWrite));
+                    self.pending
+                        .insert((client_id, NetworkClientState::NeedsWrite));
                 }
             }
         }
     }
 
     fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
-        #[cfg(not(feature = "tls-connections"))] {
+        #[cfg(not(feature = "tls-connections"))]
+        {
             Ok(ClientSocket::Plain(socket))
         }
 
-        #[cfg(feature = "tls-connections")] {
+        #[cfg(feature = "tls-connections")]
+        {
             let ssl = Ssl::new(&self.ssl.context).unwrap();
             let mut builder = SslStreamBuilder::new(ssl, socket);
             builder.set_accept_state();
             match builder.handshake() {
-                Ok(stream) =>
-                    Ok(ClientSocket::SslStream(stream)),
-                Err(HandshakeError::WouldBlock(stream)) =>
-                    Ok(ClientSocket::SslHandshake(Some(stream))),
+                Ok(stream) => Ok(ClientSocket::SslStream(stream)),
+                Err(HandshakeError::WouldBlock(stream)) => {
+                    Ok(ClientSocket::SslHandshake(Some(stream)))
+                }
                 Err(e) => {
                     debug!("OpenSSL handshake failed: {}", e);
                     Err(Error::new(ErrorKind::Other, "Connection failure"))
@@ -332,13 +371,24 @@
         info!("Connected: {}", addr);
 
         let client_id = self.server.add_client();
-        self.register_client(poll, client_id, self.create_client_socket(client_socket)?, addr);
+        self.register_client(
+            poll,
+            client_id,
+            self.create_client_socket(client_socket)?,
+            addr,
+        );
         self.flush_server_messages();
 
         Ok(())
     }
 
-    fn operation_failed(&mut self, poll: &Poll, client_id: ClientId, error: &Error, msg: &str) -> io::Result<()> {
+    fn operation_failed(
+        &mut self,
+        poll: &Poll,
+        client_id: ClientId,
+        error: &Error,
+        msg: &str,
+    ) -> io::Result<()> {
         let addr = if let Some(ref mut client) = self.clients.get_mut(client_id) {
             client.peer_addr
         } else {
@@ -348,15 +398,13 @@
         self.client_error(poll, client_id)
     }
 
-    pub fn client_readable(&mut self, poll: &Poll,
-                           client_id: ClientId) -> io::Result<()> {
-        let messages =
-            if let Some(ref mut client) = self.clients.get_mut(client_id) {
-                client.read()
-            } else {
-                warn!("invalid readable client: {}", client_id);
-                Ok((Vec::new(), NetworkClientState::Idle))
-            };
+    pub fn client_readable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
+        let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) {
+            client.read()
+        } else {
+            warn!("invalid readable client: {}", client_id);
+            Ok((Vec::new(), NetworkClientState::Idle))
+        };
 
         match messages {
             Ok((messages, state)) => {
@@ -366,15 +414,17 @@
                 match state {
                     NetworkClientState::NeedsRead => {
                         self.pending.insert((client_id, state));
-                    },
-                    NetworkClientState::Closed =>
-                        self.client_error(&poll, client_id)?,
+                    }
+                    NetworkClientState::Closed => self.client_error(&poll, client_id)?,
                     _ => {}
                 };
             }
             Err(e) => self.operation_failed(
-                poll, client_id, &e,
-                "Error while reading from client socket")?
+                poll,
+                client_id,
+                &e,
+                "Error while reading from client socket",
+            )?,
         }
 
         self.flush_server_messages();
@@ -389,31 +439,28 @@
         Ok(())
     }
 
-    pub fn client_writable(&mut self, poll: &Poll,
-                           client_id: ClientId) -> io::Result<()> {
-        let result =
-            if let Some(ref mut client) = self.clients.get_mut(client_id) {
-                client.write()
-            } else {
-                warn!("invalid writable client: {}", client_id);
-                Ok(((), NetworkClientState::Idle))
-            };
+    pub fn client_writable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
+        let result = if let Some(ref mut client) = self.clients.get_mut(client_id) {
+            client.write()
+        } else {
+            warn!("invalid writable client: {}", client_id);
+            Ok(((), NetworkClientState::Idle))
+        };
 
         match result {
             Ok(((), state)) if state == NetworkClientState::NeedsWrite => {
                 self.pending.insert((client_id, state));
-            },
+            }
             Ok(_) => {}
-            Err(e) => self.operation_failed(
-                poll, client_id, &e,
-                "Error while writing to client socket")?
+            Err(e) => {
+                self.operation_failed(poll, client_id, &e, "Error while writing to client socket")?
+            }
         }
 
         Ok(())
     }
 
-    pub fn client_error(&mut self, poll: &Poll,
-                        client_id: ClientId) -> io::Result<()> {
+    pub fn client_error(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
         self.deregister_client(poll, client_id);
         self.server.client_lost(client_id);
 
@@ -430,10 +477,8 @@
             cache.extend(self.pending.drain());
             for (id, state) in cache.drain(..) {
                 match state {
-                    NetworkClientState::NeedsRead =>
-                        self.client_readable(poll, id)?,
-                    NetworkClientState::NeedsWrite =>
-                        self.client_writable(poll, id)?,
+                    NetworkClientState::NeedsRead => self.client_readable(poll, id)?,
+                    NetworkClientState::NeedsWrite => self.client_writable(poll, id)?,
                     _ => {}
                 }
             }