rust/hedgewars-server/src/server/network.rs
changeset 14830 8ddb5842fe0b
parent 14807 b2beb784e4b5
child 14835 57ed3981db20
equal deleted inserted replaced
14829:f56936207a65 14830:8ddb5842fe0b
     9 };
     9 };
    10 
    10 
    11 use log::*;
    11 use log::*;
    12 use mio::{
    12 use mio::{
    13     net::{TcpListener, TcpStream},
    13     net::{TcpListener, TcpStream},
    14     Poll, PollOpt, Ready, Token,
    14     Evented, Poll, PollOpt, Ready, Token,
    15 };
    15 };
    16 use mio_extras::timer;
    16 use mio_extras::timer;
    17 use netbuf;
    17 use netbuf;
    18 use slab::Slab;
    18 use slab::Slab;
    19 
    19 
    46 pub enum NetworkClientState {
    46 pub enum NetworkClientState {
    47     Idle,
    47     Idle,
    48     NeedsWrite,
    48     NeedsWrite,
    49     NeedsRead,
    49     NeedsRead,
    50     Closed,
    50     Closed,
       
    51     #[cfg(feature = "tls-connections")]
       
    52     Connected,
    51 }
    53 }
    52 
    54 
    53 type NetworkResult<T> = io::Result<(T, NetworkClientState)>;
    55 type NetworkResult<T> = io::Result<(T, NetworkClientState)>;
    54 
    56 
    55 #[cfg(not(feature = "tls-connections"))]
       
    56 pub enum ClientSocket {
    57 pub enum ClientSocket {
    57     Plain(TcpStream),
    58     Plain(TcpStream),
    58 }
    59     #[cfg(feature = "tls-connections")]
    59 
       
    60 #[cfg(feature = "tls-connections")]
       
    61 pub enum ClientSocket {
       
    62     SslHandshake(Option<MidHandshakeSslStream<TcpStream>>),
    60     SslHandshake(Option<MidHandshakeSslStream<TcpStream>>),
       
    61     #[cfg(feature = "tls-connections")]
    63     SslStream(SslStream<TcpStream>),
    62     SslStream(SslStream<TcpStream>),
    64 }
    63 }
    65 
    64 
    66 impl ClientSocket {
    65 impl ClientSocket {
    67     fn inner(&self) -> &TcpStream {
    66     fn inner(&self) -> &TcpStream {
    68         #[cfg(not(feature = "tls-connections"))]
       
    69         match self {
    67         match self {
    70             ClientSocket::Plain(stream) => stream,
    68             ClientSocket::Plain(stream) => stream,
    71         }
    69             #[cfg(feature = "tls-connections")]
    72 
       
    73         #[cfg(feature = "tls-connections")]
       
    74         match self {
       
    75             ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(),
    70             ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(),
       
    71             #[cfg(feature = "tls-connections")]
    76             ClientSocket::SslHandshake(None) => unreachable!(),
    72             ClientSocket::SslHandshake(None) => unreachable!(),
       
    73             #[cfg(feature = "tls-connections")]
    77             ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref(),
    74             ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref(),
    78         }
    75         }
    79     }
    76     }
    80 }
    77 }
    81 
    78 
   115                 self.socket = ClientSocket::SslStream(stream);
   112                 self.socket = ClientSocket::SslStream(stream);
   116                 debug!(
   113                 debug!(
   117                     "TLS handshake with {} ({}) completed",
   114                     "TLS handshake with {} ({}) completed",
   118                     self.id, self.peer_addr
   115                     self.id, self.peer_addr
   119                 );
   116                 );
   120                 Ok(NetworkClientState::Idle)
   117                 Ok(NetworkClientState::Connected)
   121             }
   118             }
   122             Err(HandshakeError::WouldBlock(new_handshake)) => {
   119             Err(HandshakeError::WouldBlock(new_handshake)) => {
   123                 self.socket = ClientSocket::SslHandshake(Some(new_handshake));
   120                 self.socket = ClientSocket::SslHandshake(Some(new_handshake));
   124                 Ok(NetworkClientState::Idle)
   121                 Ok(NetworkClientState::Idle)
   125             }
   122             }
   169         };
   166         };
   170         result
   167         result
   171     }
   168     }
   172 
   169 
   173     pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> {
   170     pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> {
   174         #[cfg(not(feature = "tls-connections"))]
       
   175         match self.socket {
   171         match self.socket {
   176             ClientSocket::Plain(ref mut stream) => {
   172             ClientSocket::Plain(ref mut stream) => {
   177                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
   173                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
   178             }
   174             }
   179         }
   175             #[cfg(feature = "tls-connections")]
   180 
       
   181         #[cfg(feature = "tls-connections")]
       
   182         match self.socket {
       
   183             ClientSocket::SslHandshake(ref mut handshake_opt) => {
   176             ClientSocket::SslHandshake(ref mut handshake_opt) => {
   184                 let handshake = std::mem::replace(handshake_opt, None).unwrap();
   177                 let handshake = std::mem::replace(handshake_opt, None).unwrap();
   185                 Ok((Vec::new(), self.handshake_impl(handshake)?))
   178                 Ok((Vec::new(), self.handshake_impl(handshake)?))
   186             }
   179             }
       
   180             #[cfg(feature = "tls-connections")]
   187             ClientSocket::SslStream(ref mut stream) => {
   181             ClientSocket::SslStream(ref mut stream) => {
   188                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
   182                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
   189             }
   183             }
   190         }
   184         }
   191     }
   185     }
   208         };
   202         };
   209         result
   203         result
   210     }
   204     }
   211 
   205 
   212     pub fn write(&mut self) -> NetworkResult<()> {
   206     pub fn write(&mut self) -> NetworkResult<()> {
   213         let result = {
   207         let result = match self.socket {
   214             #[cfg(not(feature = "tls-connections"))]
   208             ClientSocket::Plain(ref mut stream) => {
   215             match self.socket {
   209                 NetworkClient::write_impl(&mut self.buf_out, stream)
   216                 ClientSocket::Plain(ref mut stream) => {
   210             }
   217                     NetworkClient::write_impl(&mut self.buf_out, stream)
   211             #[cfg(feature = "tls-connections")]
   218                 }
   212             ClientSocket::SslHandshake(ref mut handshake_opt) => {
   219             }
   213                 let handshake = std::mem::replace(handshake_opt, None).unwrap();
   220 
   214                 Ok(((), self.handshake_impl(handshake)?))
   221             #[cfg(feature = "tls-connections")]
   215             }
   222             {
   216             #[cfg(feature = "tls-connections")]
   223                 match self.socket {
   217             ClientSocket::SslStream(ref mut stream) => {
   224                     ClientSocket::SslHandshake(ref mut handshake_opt) => {
   218                 NetworkClient::write_impl(&mut self.buf_out, stream)
   225                         let handshake = std::mem::replace(handshake_opt, None).unwrap();
       
   226                         Ok(((), self.handshake_impl(handshake)?))
       
   227                     }
       
   228                     ClientSocket::SslStream(ref mut stream) => {
       
   229                         NetworkClient::write_impl(&mut self.buf_out, stream)
       
   230                     }
       
   231                 }
       
   232             }
   219             }
   233         };
   220         };
   234 
   221 
   235         self.socket.inner().flush()?;
   222         self.socket.inner().flush()?;
   236         result
   223         result
   249     }
   236     }
   250 }
   237 }
   251 
   238 
   252 #[cfg(feature = "tls-connections")]
   239 #[cfg(feature = "tls-connections")]
   253 struct ServerSsl {
   240 struct ServerSsl {
       
   241     listener: TcpListener,
   254     context: SslContext,
   242     context: SslContext,
   255 }
   243 }
   256 
   244 
   257 #[cfg(feature = "official-server")]
   245 #[cfg(feature = "official-server")]
   258 pub struct IoLayer {
   246 pub struct IoLayer {
   322     #[cfg(feature = "official-server")]
   310     #[cfg(feature = "official-server")]
   323     io: IoLayer,
   311     io: IoLayer,
   324     timer: timer::Timer<TimerData>,
   312     timer: timer::Timer<TimerData>,
   325 }
   313 }
   326 
   314 
       
   315 fn register_read<E: Evented>(poll: &Poll, evented: &E, token: mio::Token) -> io::Result<()> {
       
   316     poll.register(evented, token, Ready::readable(), PollOpt::edge())
       
   317 }
       
   318 
   327 fn create_ping_timeout(
   319 fn create_ping_timeout(
   328     timer: &mut timer::Timer<TimerData>,
   320     timer: &mut timer::Timer<TimerData>,
   329     probes_count: u8,
   321     probes_count: u8,
   330     client_id: ClientId,
   322     client_id: ClientId,
   331 ) -> timer::Timeout {
   323 ) -> timer::Timeout {
   341         TimerData(TimeoutEvent::DropClient, client_id),
   333         TimerData(TimeoutEvent::DropClient, client_id),
   342     )
   334     )
   343 }
   335 }
   344 
   336 
   345 impl NetworkLayer {
   337 impl NetworkLayer {
   346     pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer {
       
   347         let server = HWServer::new(clients_limit, rooms_limit);
       
   348         let clients = Slab::with_capacity(clients_limit);
       
   349         let pending = HashSet::with_capacity(2 * clients_limit);
       
   350         let pending_cache = Vec::with_capacity(2 * clients_limit);
       
   351         let timer = timer::Builder::default().build();
       
   352 
       
   353         NetworkLayer {
       
   354             listener,
       
   355             server,
       
   356             clients,
       
   357             pending,
       
   358             pending_cache,
       
   359             #[cfg(feature = "tls-connections")]
       
   360             ssl: NetworkLayer::create_ssl_context(),
       
   361             #[cfg(feature = "official-server")]
       
   362             io: IoLayer::new(),
       
   363             timer,
       
   364         }
       
   365     }
       
   366 
       
   367     #[cfg(feature = "tls-connections")]
   338     #[cfg(feature = "tls-connections")]
   368     fn create_ssl_context() -> ServerSsl {
   339     fn create_ssl_context(listener: TcpListener) -> ServerSsl {
   369         let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap();
   340         let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap();
   370         builder.set_verify(SslVerifyMode::NONE);
   341         builder.set_verify(SslVerifyMode::NONE);
   371         builder.set_read_ahead(true);
   342         builder.set_read_ahead(true);
   372         builder
   343         builder
   373             .set_certificate_file("ssl/cert.pem", SslFiletype::PEM)
   344             .set_certificate_file("ssl/cert.pem", SslFiletype::PEM)
   376             .set_private_key_file("ssl/key.pem", SslFiletype::PEM)
   347             .set_private_key_file("ssl/key.pem", SslFiletype::PEM)
   377             .unwrap();
   348             .unwrap();
   378         builder.set_options(SslOptions::NO_COMPRESSION);
   349         builder.set_options(SslOptions::NO_COMPRESSION);
   379         builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap();
   350         builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap();
   380         ServerSsl {
   351         ServerSsl {
       
   352             listener,
   381             context: builder.build(),
   353             context: builder.build(),
   382         }
   354         }
   383     }
   355     }
   384 
   356 
   385     pub fn register_server(&self, poll: &Poll) -> io::Result<()> {
   357     pub fn register(&self, poll: &Poll) -> io::Result<()> {
   386         poll.register(
   358         register_read(poll, &self.listener, utils::SERVER_TOKEN)?;
   387             &self.listener,
   359         #[cfg(feature = "tls-connections")]
   388             utils::SERVER_TOKEN,
   360         register_read(poll, &self.listener, utils::SECURE_SERVER_TOKEN)?;
   389             Ready::readable(),
   361         register_read(poll, &self.timer, utils::TIMER_TOKEN)?;
   390             PollOpt::edge(),
       
   391         )?;
       
   392 
       
   393         poll.register(
       
   394             &self.timer,
       
   395             utils::TIMER_TOKEN,
       
   396             Ready::readable(),
       
   397             PollOpt::edge(),
       
   398         )?;
       
   399 
   362 
   400         #[cfg(feature = "official-server")]
   363         #[cfg(feature = "official-server")]
   401         self.io.io_thread.register_rx(poll, utils::IO_TOKEN)?;
   364         self.io.io_thread.register_rx(poll, utils::IO_TOKEN)?;
   402 
   365 
   403         Ok(())
   366         Ok(())
   446 
   409 
   447         client_id
   410         client_id
   448     }
   411     }
   449 
   412 
   450     fn handle_response(&mut self, mut response: handlers::Response, poll: &Poll) {
   413     fn handle_response(&mut self, mut response: handlers::Response, poll: &Poll) {
       
   414         if response.is_empty() {
       
   415             return;
       
   416         }
       
   417 
   451         debug!("{} pending server messages", response.len());
   418         debug!("{} pending server messages", response.len());
   452         let output = response.extract_messages(&mut self.server);
   419         let output = response.extract_messages(&mut self.server);
   453         for (clients, message) in output {
   420         for (clients, message) in output {
   454             debug!("Message {:?} to {:?}", message, clients);
   421             debug!("Message {:?} to {:?}", message, clients);
   455             let msg_string = message.to_raw_protocol();
   422             let msg_string = message.to_raw_protocol();
   510             handlers::handle_io_result(&mut self.server, client_id, &mut response, result);
   477             handlers::handle_io_result(&mut self.server, client_id, &mut response, result);
   511         }
   478         }
   512     }
   479     }
   513 
   480 
   514     fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
   481     fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
   515         #[cfg(not(feature = "tls-connections"))]
   482         Ok(ClientSocket::Plain(socket))
   516         {
   483     }
   517             Ok(ClientSocket::Plain(socket))
   484 
   518         }
   485     #[cfg(feature = "tls-connections")]
   519 
   486     fn create_client_secure_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
   520         #[cfg(feature = "tls-connections")]
   487         let ssl = Ssl::new(&self.ssl.context).unwrap();
   521         {
   488         let mut builder = SslStreamBuilder::new(ssl, socket);
   522             let ssl = Ssl::new(&self.ssl.context).unwrap();
   489         builder.set_accept_state();
   523             let mut builder = SslStreamBuilder::new(ssl, socket);
   490         match builder.handshake() {
   524             builder.set_accept_state();
   491             Ok(stream) => Ok(ClientSocket::SslStream(stream)),
   525             match builder.handshake() {
   492             Err(HandshakeError::WouldBlock(stream)) => Ok(ClientSocket::SslHandshake(Some(stream))),
   526                 Ok(stream) => Ok(ClientSocket::SslStream(stream)),
   493             Err(e) => {
   527                 Err(HandshakeError::WouldBlock(stream)) => {
   494                 debug!("OpenSSL handshake failed: {}", e);
   528                     Ok(ClientSocket::SslHandshake(Some(stream)))
   495                 Err(Error::new(ErrorKind::Other, "Connection failure"))
   529                 }
   496             }
   530                 Err(e) => {
   497         }
   531                     debug!("OpenSSL handshake failed: {}", e);
   498     }
   532                     Err(Error::new(ErrorKind::Other, "Connection failure"))
   499 
   533                 }
   500     pub fn accept_client(&mut self, poll: &Poll, server_token: mio::Token) -> io::Result<()> {
   534             }
       
   535         }
       
   536     }
       
   537 
       
   538     pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> {
       
   539         let (client_socket, addr) = self.listener.accept()?;
   501         let (client_socket, addr) = self.listener.accept()?;
   540         info!("Connected: {}", addr);
   502         info!("Connected: {}", addr);
   541 
   503 
   542         let client_id = self.register_client(poll, self.create_client_socket(client_socket)?, addr);
   504         match server_token {
   543 
   505             utils::SERVER_TOKEN => {
   544         let mut response = handlers::Response::new(client_id);
   506                 let client_id =
   545 
   507                     self.register_client(poll, self.create_client_socket(client_socket)?, addr);
   546         handlers::handle_client_accept(&mut self.server, client_id, &mut response);
   508                 let mut response = handlers::Response::new(client_id);
   547 
   509                 handlers::handle_client_accept(&mut self.server, client_id, &mut response);
   548         if !response.is_empty() {
   510                 self.handle_response(response, poll);
   549             self.handle_response(response, poll);
   511             }
       
   512             #[cfg(feature = "tls-connections")]
       
   513             utils::SECURE_SERVER_TOKEN => {
       
   514                 self.register_client(poll, self.create_client_secure_socket(client_socket)?, addr);
       
   515             }
       
   516             _ => unreachable!(),
   550         }
   517         }
   551 
   518 
   552         Ok(())
   519         Ok(())
   553     }
   520     }
   554 
   521 
   593                 match state {
   560                 match state {
   594                     NetworkClientState::NeedsRead => {
   561                     NetworkClientState::NeedsRead => {
   595                         self.pending.insert((client_id, state));
   562                         self.pending.insert((client_id, state));
   596                     }
   563                     }
   597                     NetworkClientState::Closed => self.client_error(&poll, client_id)?,
   564                     NetworkClientState::Closed => self.client_error(&poll, client_id)?,
       
   565                     #[cfg(feature = "tls-connections")]
       
   566                     NetworkClientState::Connected => {
       
   567                         let mut response = handlers::Response::new(client_id);
       
   568                         handlers::handle_client_accept(&mut self.server, client_id, &mut response);
       
   569                         self.handle_response(response, poll);
       
   570                     }
   598                     _ => {}
   571                     _ => {}
   599                 };
   572                 };
   600             }
   573             }
   601             Err(e) => self.operation_failed(
   574             Err(e) => self.operation_failed(
   602                 poll,
   575                 poll,
   604                 &e,
   577                 &e,
   605                 "Error while reading from client socket",
   578                 "Error while reading from client socket",
   606             )?,
   579             )?,
   607         }
   580         }
   608 
   581 
   609         if !response.is_empty() {
   582         self.handle_response(response, poll);
   610             self.handle_response(response, poll);
       
   611         }
       
   612 
   583 
   613         Ok(())
   584         Ok(())
   614     }
   585     }
   615 
   586 
   616     pub fn client_writable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
   587     pub fn client_writable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
   661             swap(&mut cache, &mut self.pending_cache);
   632             swap(&mut cache, &mut self.pending_cache);
   662         }
   633         }
   663         Ok(())
   634         Ok(())
   664     }
   635     }
   665 }
   636 }
       
   637 
       
   638 pub struct NetworkLayerBuilder {
       
   639     listener: Option<TcpListener>,
       
   640     secure_listener: Option<TcpListener>,
       
   641     clients_capacity: usize,
       
   642     rooms_capacity: usize,
       
   643 }
       
   644 
       
   645 impl Default for NetworkLayerBuilder {
       
   646     fn default() -> Self {
       
   647         Self {
       
   648             clients_capacity: 1024,
       
   649             rooms_capacity: 512,
       
   650             listener: None,
       
   651             secure_listener: None,
       
   652         }
       
   653     }
       
   654 }
       
   655 
       
   656 impl NetworkLayerBuilder {
       
   657     pub fn with_listener(self, listener: TcpListener) -> Self {
       
   658         Self {
       
   659             listener: Some(listener),
       
   660             ..self
       
   661         }
       
   662     }
       
   663 
       
   664     pub fn with_secure_listener(self, listener: TcpListener) -> Self {
       
   665         Self {
       
   666             secure_listener: Some(listener),
       
   667             ..self
       
   668         }
       
   669     }
       
   670 
       
   671     pub fn build(self) -> NetworkLayer {
       
   672         let server = HWServer::new(self.clients_capacity, self.rooms_capacity);
       
   673         let clients = Slab::with_capacity(self.clients_capacity);
       
   674         let pending = HashSet::with_capacity(2 * self.clients_capacity);
       
   675         let pending_cache = Vec::with_capacity(2 * self.clients_capacity);
       
   676         let timer = timer::Builder::default().build();
       
   677 
       
   678         NetworkLayer {
       
   679             listener: self.listener.expect("No listener provided"),
       
   680             server,
       
   681             clients,
       
   682             pending,
       
   683             pending_cache,
       
   684             #[cfg(feature = "tls-connections")]
       
   685             ssl: NetworkLayer::create_ssl_context(
       
   686                 self.secure_listener.expect("No secure listener provided"),
       
   687             ),
       
   688             #[cfg(feature = "official-server")]
       
   689             io: IoLayer::new(),
       
   690             timer,
       
   691         }
       
   692     }
       
   693 }