rust/hedgewars-server/src/server/network.rs
changeset 14830 8ddb5842fe0b
parent 14807 b2beb784e4b5
child 14835 57ed3981db20
--- a/rust/hedgewars-server/src/server/network.rs	Tue Apr 23 15:54:06 2019 +0200
+++ b/rust/hedgewars-server/src/server/network.rs	Wed Apr 24 16:21:46 2019 +0300
@@ -11,7 +11,7 @@
 use log::*;
 use mio::{
     net::{TcpListener, TcpStream},
-    Poll, PollOpt, Ready, Token,
+    Evented, Poll, PollOpt, Ready, Token,
 };
 use mio_extras::timer;
 use netbuf;
@@ -48,32 +48,29 @@
     NeedsWrite,
     NeedsRead,
     Closed,
+    #[cfg(feature = "tls-connections")]
+    Connected,
 }
 
 type NetworkResult<T> = io::Result<(T, NetworkClientState)>;
 
-#[cfg(not(feature = "tls-connections"))]
 pub enum ClientSocket {
     Plain(TcpStream),
-}
-
-#[cfg(feature = "tls-connections")]
-pub enum ClientSocket {
+    #[cfg(feature = "tls-connections")]
     SslHandshake(Option<MidHandshakeSslStream<TcpStream>>),
+    #[cfg(feature = "tls-connections")]
     SslStream(SslStream<TcpStream>),
 }
 
 impl ClientSocket {
     fn inner(&self) -> &TcpStream {
-        #[cfg(not(feature = "tls-connections"))]
         match self {
             ClientSocket::Plain(stream) => stream,
-        }
-
-        #[cfg(feature = "tls-connections")]
-        match self {
+            #[cfg(feature = "tls-connections")]
             ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(),
+            #[cfg(feature = "tls-connections")]
             ClientSocket::SslHandshake(None) => unreachable!(),
+            #[cfg(feature = "tls-connections")]
             ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref(),
         }
     }
@@ -117,7 +114,7 @@
                     "TLS handshake with {} ({}) completed",
                     self.id, self.peer_addr
                 );
-                Ok(NetworkClientState::Idle)
+                Ok(NetworkClientState::Connected)
             }
             Err(HandshakeError::WouldBlock(new_handshake)) => {
                 self.socket = ClientSocket::SslHandshake(Some(new_handshake));
@@ -171,19 +168,16 @@
     }
 
     pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> {
-        #[cfg(not(feature = "tls-connections"))]
         match self.socket {
             ClientSocket::Plain(ref mut stream) => {
                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
             }
-        }
-
-        #[cfg(feature = "tls-connections")]
-        match self.socket {
+            #[cfg(feature = "tls-connections")]
             ClientSocket::SslHandshake(ref mut handshake_opt) => {
                 let handshake = std::mem::replace(handshake_opt, None).unwrap();
                 Ok((Vec::new(), self.handshake_impl(handshake)?))
             }
+            #[cfg(feature = "tls-connections")]
             ClientSocket::SslStream(ref mut stream) => {
                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
             }
@@ -210,25 +204,18 @@
     }
 
     pub fn write(&mut self) -> NetworkResult<()> {
-        let result = {
-            #[cfg(not(feature = "tls-connections"))]
-            match self.socket {
-                ClientSocket::Plain(ref mut stream) => {
-                    NetworkClient::write_impl(&mut self.buf_out, stream)
-                }
+        let result = match self.socket {
+            ClientSocket::Plain(ref mut stream) => {
+                NetworkClient::write_impl(&mut self.buf_out, stream)
             }
-
             #[cfg(feature = "tls-connections")]
-            {
-                match self.socket {
-                    ClientSocket::SslHandshake(ref mut handshake_opt) => {
-                        let handshake = std::mem::replace(handshake_opt, None).unwrap();
-                        Ok(((), self.handshake_impl(handshake)?))
-                    }
-                    ClientSocket::SslStream(ref mut stream) => {
-                        NetworkClient::write_impl(&mut self.buf_out, stream)
-                    }
-                }
+            ClientSocket::SslHandshake(ref mut handshake_opt) => {
+                let handshake = std::mem::replace(handshake_opt, None).unwrap();
+                Ok(((), self.handshake_impl(handshake)?))
+            }
+            #[cfg(feature = "tls-connections")]
+            ClientSocket::SslStream(ref mut stream) => {
+                NetworkClient::write_impl(&mut self.buf_out, stream)
             }
         };
 
@@ -251,6 +238,7 @@
 
 #[cfg(feature = "tls-connections")]
 struct ServerSsl {
+    listener: TcpListener,
     context: SslContext,
 }
 
@@ -324,6 +312,10 @@
     timer: timer::Timer<TimerData>,
 }
 
+fn register_read<E: Evented>(poll: &Poll, evented: &E, token: mio::Token) -> io::Result<()> {
+    poll.register(evented, token, Ready::readable(), PollOpt::edge())
+}
+
 fn create_ping_timeout(
     timer: &mut timer::Timer<TimerData>,
     probes_count: u8,
@@ -343,29 +335,8 @@
 }
 
 impl NetworkLayer {
-    pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer {
-        let server = HWServer::new(clients_limit, rooms_limit);
-        let clients = Slab::with_capacity(clients_limit);
-        let pending = HashSet::with_capacity(2 * clients_limit);
-        let pending_cache = Vec::with_capacity(2 * clients_limit);
-        let timer = timer::Builder::default().build();
-
-        NetworkLayer {
-            listener,
-            server,
-            clients,
-            pending,
-            pending_cache,
-            #[cfg(feature = "tls-connections")]
-            ssl: NetworkLayer::create_ssl_context(),
-            #[cfg(feature = "official-server")]
-            io: IoLayer::new(),
-            timer,
-        }
-    }
-
     #[cfg(feature = "tls-connections")]
-    fn create_ssl_context() -> ServerSsl {
+    fn create_ssl_context(listener: TcpListener) -> ServerSsl {
         let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap();
         builder.set_verify(SslVerifyMode::NONE);
         builder.set_read_ahead(true);
@@ -378,24 +349,16 @@
         builder.set_options(SslOptions::NO_COMPRESSION);
         builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap();
         ServerSsl {
+            listener,
             context: builder.build(),
         }
     }
 
-    pub fn register_server(&self, poll: &Poll) -> io::Result<()> {
-        poll.register(
-            &self.listener,
-            utils::SERVER_TOKEN,
-            Ready::readable(),
-            PollOpt::edge(),
-        )?;
-
-        poll.register(
-            &self.timer,
-            utils::TIMER_TOKEN,
-            Ready::readable(),
-            PollOpt::edge(),
-        )?;
+    pub fn register(&self, poll: &Poll) -> io::Result<()> {
+        register_read(poll, &self.listener, utils::SERVER_TOKEN)?;
+        #[cfg(feature = "tls-connections")]
+        register_read(poll, &self.listener, utils::SECURE_SERVER_TOKEN)?;
+        register_read(poll, &self.timer, utils::TIMER_TOKEN)?;
 
         #[cfg(feature = "official-server")]
         self.io.io_thread.register_rx(poll, utils::IO_TOKEN)?;
@@ -448,6 +411,10 @@
     }
 
     fn handle_response(&mut self, mut response: handlers::Response, poll: &Poll) {
+        if response.is_empty() {
+            return;
+        }
+
         debug!("{} pending server messages", response.len());
         let output = response.extract_messages(&mut self.server);
         for (clients, message) in output {
@@ -512,41 +479,41 @@
     }
 
     fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
-        #[cfg(not(feature = "tls-connections"))]
-        {
-            Ok(ClientSocket::Plain(socket))
-        }
+        Ok(ClientSocket::Plain(socket))
+    }
 
-        #[cfg(feature = "tls-connections")]
-        {
-            let ssl = Ssl::new(&self.ssl.context).unwrap();
-            let mut builder = SslStreamBuilder::new(ssl, socket);
-            builder.set_accept_state();
-            match builder.handshake() {
-                Ok(stream) => Ok(ClientSocket::SslStream(stream)),
-                Err(HandshakeError::WouldBlock(stream)) => {
-                    Ok(ClientSocket::SslHandshake(Some(stream)))
-                }
-                Err(e) => {
-                    debug!("OpenSSL handshake failed: {}", e);
-                    Err(Error::new(ErrorKind::Other, "Connection failure"))
-                }
+    #[cfg(feature = "tls-connections")]
+    fn create_client_secure_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
+        let ssl = Ssl::new(&self.ssl.context).unwrap();
+        let mut builder = SslStreamBuilder::new(ssl, socket);
+        builder.set_accept_state();
+        match builder.handshake() {
+            Ok(stream) => Ok(ClientSocket::SslStream(stream)),
+            Err(HandshakeError::WouldBlock(stream)) => Ok(ClientSocket::SslHandshake(Some(stream))),
+            Err(e) => {
+                debug!("OpenSSL handshake failed: {}", e);
+                Err(Error::new(ErrorKind::Other, "Connection failure"))
             }
         }
     }
 
-    pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> {
+    pub fn accept_client(&mut self, poll: &Poll, server_token: mio::Token) -> io::Result<()> {
         let (client_socket, addr) = self.listener.accept()?;
         info!("Connected: {}", addr);
 
-        let client_id = self.register_client(poll, self.create_client_socket(client_socket)?, addr);
-
-        let mut response = handlers::Response::new(client_id);
-
-        handlers::handle_client_accept(&mut self.server, client_id, &mut response);
-
-        if !response.is_empty() {
-            self.handle_response(response, poll);
+        match server_token {
+            utils::SERVER_TOKEN => {
+                let client_id =
+                    self.register_client(poll, self.create_client_socket(client_socket)?, addr);
+                let mut response = handlers::Response::new(client_id);
+                handlers::handle_client_accept(&mut self.server, client_id, &mut response);
+                self.handle_response(response, poll);
+            }
+            #[cfg(feature = "tls-connections")]
+            utils::SECURE_SERVER_TOKEN => {
+                self.register_client(poll, self.create_client_secure_socket(client_socket)?, addr);
+            }
+            _ => unreachable!(),
         }
 
         Ok(())
@@ -595,6 +562,12 @@
                         self.pending.insert((client_id, state));
                     }
                     NetworkClientState::Closed => self.client_error(&poll, client_id)?,
+                    #[cfg(feature = "tls-connections")]
+                    NetworkClientState::Connected => {
+                        let mut response = handlers::Response::new(client_id);
+                        handlers::handle_client_accept(&mut self.server, client_id, &mut response);
+                        self.handle_response(response, poll);
+                    }
                     _ => {}
                 };
             }
@@ -606,9 +579,7 @@
             )?,
         }
 
-        if !response.is_empty() {
-            self.handle_response(response, poll);
-        }
+        self.handle_response(response, poll);
 
         Ok(())
     }
@@ -663,3 +634,60 @@
         Ok(())
     }
 }
+
+pub struct NetworkLayerBuilder {
+    listener: Option<TcpListener>,
+    secure_listener: Option<TcpListener>,
+    clients_capacity: usize,
+    rooms_capacity: usize,
+}
+
+impl Default for NetworkLayerBuilder {
+    fn default() -> Self {
+        Self {
+            clients_capacity: 1024,
+            rooms_capacity: 512,
+            listener: None,
+            secure_listener: None,
+        }
+    }
+}
+
+impl NetworkLayerBuilder {
+    pub fn with_listener(self, listener: TcpListener) -> Self {
+        Self {
+            listener: Some(listener),
+            ..self
+        }
+    }
+
+    pub fn with_secure_listener(self, listener: TcpListener) -> Self {
+        Self {
+            secure_listener: Some(listener),
+            ..self
+        }
+    }
+
+    pub fn build(self) -> NetworkLayer {
+        let server = HWServer::new(self.clients_capacity, self.rooms_capacity);
+        let clients = Slab::with_capacity(self.clients_capacity);
+        let pending = HashSet::with_capacity(2 * self.clients_capacity);
+        let pending_cache = Vec::with_capacity(2 * self.clients_capacity);
+        let timer = timer::Builder::default().build();
+
+        NetworkLayer {
+            listener: self.listener.expect("No listener provided"),
+            server,
+            clients,
+            pending,
+            pending_cache,
+            #[cfg(feature = "tls-connections")]
+            ssl: NetworkLayer::create_ssl_context(
+                self.secure_listener.expect("No secure listener provided"),
+            ),
+            #[cfg(feature = "official-server")]
+            io: IoLayer::new(),
+            timer,
+        }
+    }
+}