diff -r f56936207a65 -r 8ddb5842fe0b rust/hedgewars-server/src/server/network.rs --- a/rust/hedgewars-server/src/server/network.rs Tue Apr 23 15:54:06 2019 +0200 +++ b/rust/hedgewars-server/src/server/network.rs Wed Apr 24 16:21:46 2019 +0300 @@ -11,7 +11,7 @@ use log::*; use mio::{ net::{TcpListener, TcpStream}, - Poll, PollOpt, Ready, Token, + Evented, Poll, PollOpt, Ready, Token, }; use mio_extras::timer; use netbuf; @@ -48,32 +48,29 @@ NeedsWrite, NeedsRead, Closed, + #[cfg(feature = "tls-connections")] + Connected, } type NetworkResult = io::Result<(T, NetworkClientState)>; -#[cfg(not(feature = "tls-connections"))] pub enum ClientSocket { Plain(TcpStream), -} - -#[cfg(feature = "tls-connections")] -pub enum ClientSocket { + #[cfg(feature = "tls-connections")] SslHandshake(Option>), + #[cfg(feature = "tls-connections")] SslStream(SslStream), } impl ClientSocket { fn inner(&self) -> &TcpStream { - #[cfg(not(feature = "tls-connections"))] match self { ClientSocket::Plain(stream) => stream, - } - - #[cfg(feature = "tls-connections")] - match self { + #[cfg(feature = "tls-connections")] ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(), + #[cfg(feature = "tls-connections")] ClientSocket::SslHandshake(None) => unreachable!(), + #[cfg(feature = "tls-connections")] ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref(), } } @@ -117,7 +114,7 @@ "TLS handshake with {} ({}) completed", self.id, self.peer_addr ); - Ok(NetworkClientState::Idle) + Ok(NetworkClientState::Connected) } Err(HandshakeError::WouldBlock(new_handshake)) => { self.socket = ClientSocket::SslHandshake(Some(new_handshake)); @@ -171,19 +168,16 @@ } pub fn read(&mut self) -> NetworkResult> { - #[cfg(not(feature = "tls-connections"))] match self.socket { ClientSocket::Plain(ref mut stream) => { NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) } - } - - #[cfg(feature = "tls-connections")] - match self.socket { + #[cfg(feature = "tls-connections")] ClientSocket::SslHandshake(ref mut handshake_opt) => { let handshake = std::mem::replace(handshake_opt, None).unwrap(); Ok((Vec::new(), self.handshake_impl(handshake)?)) } + #[cfg(feature = "tls-connections")] ClientSocket::SslStream(ref mut stream) => { NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) } @@ -210,25 +204,18 @@ } pub fn write(&mut self) -> NetworkResult<()> { - let result = { - #[cfg(not(feature = "tls-connections"))] - match self.socket { - ClientSocket::Plain(ref mut stream) => { - NetworkClient::write_impl(&mut self.buf_out, stream) - } + let result = match self.socket { + ClientSocket::Plain(ref mut stream) => { + NetworkClient::write_impl(&mut self.buf_out, stream) } - #[cfg(feature = "tls-connections")] - { - match self.socket { - ClientSocket::SslHandshake(ref mut handshake_opt) => { - let handshake = std::mem::replace(handshake_opt, None).unwrap(); - Ok(((), self.handshake_impl(handshake)?)) - } - ClientSocket::SslStream(ref mut stream) => { - NetworkClient::write_impl(&mut self.buf_out, stream) - } - } + ClientSocket::SslHandshake(ref mut handshake_opt) => { + let handshake = std::mem::replace(handshake_opt, None).unwrap(); + Ok(((), self.handshake_impl(handshake)?)) + } + #[cfg(feature = "tls-connections")] + ClientSocket::SslStream(ref mut stream) => { + NetworkClient::write_impl(&mut self.buf_out, stream) } }; @@ -251,6 +238,7 @@ #[cfg(feature = "tls-connections")] struct ServerSsl { + listener: TcpListener, context: SslContext, } @@ -324,6 +312,10 @@ timer: timer::Timer, } +fn register_read(poll: &Poll, evented: &E, token: mio::Token) -> io::Result<()> { + poll.register(evented, token, Ready::readable(), PollOpt::edge()) +} + fn create_ping_timeout( timer: &mut timer::Timer, probes_count: u8, @@ -343,29 +335,8 @@ } impl NetworkLayer { - pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer { - let server = HWServer::new(clients_limit, rooms_limit); - let clients = Slab::with_capacity(clients_limit); - let pending = HashSet::with_capacity(2 * clients_limit); - let pending_cache = Vec::with_capacity(2 * clients_limit); - let timer = timer::Builder::default().build(); - - NetworkLayer { - listener, - server, - clients, - pending, - pending_cache, - #[cfg(feature = "tls-connections")] - ssl: NetworkLayer::create_ssl_context(), - #[cfg(feature = "official-server")] - io: IoLayer::new(), - timer, - } - } - #[cfg(feature = "tls-connections")] - fn create_ssl_context() -> ServerSsl { + fn create_ssl_context(listener: TcpListener) -> ServerSsl { let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap(); builder.set_verify(SslVerifyMode::NONE); builder.set_read_ahead(true); @@ -378,24 +349,16 @@ builder.set_options(SslOptions::NO_COMPRESSION); builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap(); ServerSsl { + listener, context: builder.build(), } } - pub fn register_server(&self, poll: &Poll) -> io::Result<()> { - poll.register( - &self.listener, - utils::SERVER_TOKEN, - Ready::readable(), - PollOpt::edge(), - )?; - - poll.register( - &self.timer, - utils::TIMER_TOKEN, - Ready::readable(), - PollOpt::edge(), - )?; + pub fn register(&self, poll: &Poll) -> io::Result<()> { + register_read(poll, &self.listener, utils::SERVER_TOKEN)?; + #[cfg(feature = "tls-connections")] + register_read(poll, &self.listener, utils::SECURE_SERVER_TOKEN)?; + register_read(poll, &self.timer, utils::TIMER_TOKEN)?; #[cfg(feature = "official-server")] self.io.io_thread.register_rx(poll, utils::IO_TOKEN)?; @@ -448,6 +411,10 @@ } fn handle_response(&mut self, mut response: handlers::Response, poll: &Poll) { + if response.is_empty() { + return; + } + debug!("{} pending server messages", response.len()); let output = response.extract_messages(&mut self.server); for (clients, message) in output { @@ -512,41 +479,41 @@ } fn create_client_socket(&self, socket: TcpStream) -> io::Result { - #[cfg(not(feature = "tls-connections"))] - { - Ok(ClientSocket::Plain(socket)) - } + Ok(ClientSocket::Plain(socket)) + } - #[cfg(feature = "tls-connections")] - { - let ssl = Ssl::new(&self.ssl.context).unwrap(); - let mut builder = SslStreamBuilder::new(ssl, socket); - builder.set_accept_state(); - match builder.handshake() { - Ok(stream) => Ok(ClientSocket::SslStream(stream)), - Err(HandshakeError::WouldBlock(stream)) => { - Ok(ClientSocket::SslHandshake(Some(stream))) - } - Err(e) => { - debug!("OpenSSL handshake failed: {}", e); - Err(Error::new(ErrorKind::Other, "Connection failure")) - } + #[cfg(feature = "tls-connections")] + fn create_client_secure_socket(&self, socket: TcpStream) -> io::Result { + let ssl = Ssl::new(&self.ssl.context).unwrap(); + let mut builder = SslStreamBuilder::new(ssl, socket); + builder.set_accept_state(); + match builder.handshake() { + Ok(stream) => Ok(ClientSocket::SslStream(stream)), + Err(HandshakeError::WouldBlock(stream)) => Ok(ClientSocket::SslHandshake(Some(stream))), + Err(e) => { + debug!("OpenSSL handshake failed: {}", e); + Err(Error::new(ErrorKind::Other, "Connection failure")) } } } - pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> { + pub fn accept_client(&mut self, poll: &Poll, server_token: mio::Token) -> io::Result<()> { let (client_socket, addr) = self.listener.accept()?; info!("Connected: {}", addr); - let client_id = self.register_client(poll, self.create_client_socket(client_socket)?, addr); - - let mut response = handlers::Response::new(client_id); - - handlers::handle_client_accept(&mut self.server, client_id, &mut response); - - if !response.is_empty() { - self.handle_response(response, poll); + match server_token { + utils::SERVER_TOKEN => { + let client_id = + self.register_client(poll, self.create_client_socket(client_socket)?, addr); + let mut response = handlers::Response::new(client_id); + handlers::handle_client_accept(&mut self.server, client_id, &mut response); + self.handle_response(response, poll); + } + #[cfg(feature = "tls-connections")] + utils::SECURE_SERVER_TOKEN => { + self.register_client(poll, self.create_client_secure_socket(client_socket)?, addr); + } + _ => unreachable!(), } Ok(()) @@ -595,6 +562,12 @@ self.pending.insert((client_id, state)); } NetworkClientState::Closed => self.client_error(&poll, client_id)?, + #[cfg(feature = "tls-connections")] + NetworkClientState::Connected => { + let mut response = handlers::Response::new(client_id); + handlers::handle_client_accept(&mut self.server, client_id, &mut response); + self.handle_response(response, poll); + } _ => {} }; } @@ -606,9 +579,7 @@ )?, } - if !response.is_empty() { - self.handle_response(response, poll); - } + self.handle_response(response, poll); Ok(()) } @@ -663,3 +634,60 @@ Ok(()) } } + +pub struct NetworkLayerBuilder { + listener: Option, + secure_listener: Option, + clients_capacity: usize, + rooms_capacity: usize, +} + +impl Default for NetworkLayerBuilder { + fn default() -> Self { + Self { + clients_capacity: 1024, + rooms_capacity: 512, + listener: None, + secure_listener: None, + } + } +} + +impl NetworkLayerBuilder { + pub fn with_listener(self, listener: TcpListener) -> Self { + Self { + listener: Some(listener), + ..self + } + } + + pub fn with_secure_listener(self, listener: TcpListener) -> Self { + Self { + secure_listener: Some(listener), + ..self + } + } + + pub fn build(self) -> NetworkLayer { + let server = HWServer::new(self.clients_capacity, self.rooms_capacity); + let clients = Slab::with_capacity(self.clients_capacity); + let pending = HashSet::with_capacity(2 * self.clients_capacity); + let pending_cache = Vec::with_capacity(2 * self.clients_capacity); + let timer = timer::Builder::default().build(); + + NetworkLayer { + listener: self.listener.expect("No listener provided"), + server, + clients, + pending, + pending_cache, + #[cfg(feature = "tls-connections")] + ssl: NetworkLayer::create_ssl_context( + self.secure_listener.expect("No secure listener provided"), + ), + #[cfg(feature = "official-server")] + io: IoLayer::new(), + timer, + } + } +}