rust/hedgewars-server/src/server/network.rs
changeset 14436 06672690d71b
parent 14413 e335b3120f59
child 14478 98ef2913ec73
equal deleted inserted replaced
14435:6843c4551cde 14436:06672690d71b
       
     1 extern crate slab;
       
     2 
       
     3 use std::{
       
     4     io, io::{Error, ErrorKind, Read, Write},
       
     5     net::{SocketAddr, IpAddr, Ipv4Addr},
       
     6     collections::HashSet,
       
     7     mem::{swap, replace}
       
     8 };
       
     9 
       
    10 use mio::{
       
    11     net::{TcpStream, TcpListener},
       
    12     Poll, PollOpt, Ready, Token
       
    13 };
       
    14 use netbuf;
       
    15 use slab::Slab;
       
    16 use log::*;
       
    17 
       
    18 use crate::{
       
    19     utils,
       
    20     protocol::{ProtocolDecoder, messages::*}
       
    21 };
       
    22 use super::{
       
    23     io::FileServerIO,
       
    24     core::{HWServer},
       
    25     coretypes::ClientId
       
    26 };
       
    27 #[cfg(feature = "tls-connections")]
       
    28 use openssl::{
       
    29     ssl::{
       
    30         SslMethod, SslContext, Ssl, SslContextBuilder,
       
    31         SslVerifyMode, SslFiletype, SslOptions,
       
    32         SslStreamBuilder, HandshakeError, MidHandshakeSslStream, SslStream
       
    33     },
       
    34     error::ErrorStack
       
    35 };
       
    36 
       
    37 const MAX_BYTES_PER_READ: usize = 2048;
       
    38 
       
/// Follow-up action a client connection needs after a read or write attempt.
///
/// Values are paired with a client id and stored in `NetworkLayer::pending`,
/// which `NetworkLayer::on_idle` later drains to retry the operation.
#[derive(Hash, Eq, PartialEq, Copy, Clone)]
pub enum NetworkClientState {
    /// Nothing further to do right now (e.g. the read hit `WouldBlock`, or a
    /// write attempt finished without the socket blocking).
    Idle,
    /// The socket could not accept all queued output; `write` must be retried.
    NeedsWrite,
    /// More input may be available; `read` must be called again.
    NeedsRead,
    /// The peer reached EOF; the connection should be torn down.
    Closed,
}
       
    46 
       
/// Result of a network operation: the produced value together with the
/// follow-up state of the client, or the I/O error that aborted it.
type NetworkResult<T> = io::Result<(T, NetworkClientState)>;
       
    48 
       
/// Transport of a client connection when the `tls-connections` feature is off.
#[cfg(not(feature = "tls-connections"))]
pub enum ClientSocket {
    /// Unencrypted TCP stream.
    Plain(TcpStream)
}

/// Transport of a client connection when the `tls-connections` feature is on.
#[cfg(feature = "tls-connections")]
pub enum ClientSocket {
    /// TLS handshake still in progress. The slot is emptied by
    /// `NetworkClient::read`/`write` while the handshake is advanced and is
    /// refilled (or replaced by `SslStream`) by `handshake_impl`.
    SslHandshake(Option<MidHandshakeSslStream<TcpStream>>),
    /// Handshake finished; application data flows through this stream.
    SslStream(SslStream<TcpStream>)
}
       
    59 
       
    60 impl ClientSocket {
       
    61     fn inner(&self) -> &TcpStream {
       
    62         #[cfg(not(feature = "tls-connections"))]
       
    63         match self {
       
    64             ClientSocket::Plain(stream) => stream,
       
    65         }
       
    66 
       
    67         #[cfg(feature = "tls-connections")]
       
    68         match self {
       
    69             ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(),
       
    70             ClientSocket::SslHandshake(None) => unreachable!(),
       
    71             ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref()
       
    72         }
       
    73     }
       
    74 }
       
    75 
       
/// Network-side state of one connected client.
pub struct NetworkClient {
    /// Identifier of this client; also used as the poll token on registration.
    id: ClientId,
    /// Underlying transport (plain TCP or TLS, depending on build features).
    socket: ClientSocket,
    /// Remote address, kept for log messages.
    peer_addr: SocketAddr,
    /// Accumulates incoming bytes and yields parsed protocol messages.
    decoder: ProtocolDecoder,
    /// Outgoing bytes queued by the `send_*` methods until `write` drains them.
    buf_out: netbuf::Buf
}
       
    83 
       
    84 impl NetworkClient {
       
    85     pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient {
       
    86         NetworkClient {
       
    87             id, socket, peer_addr,
       
    88             decoder: ProtocolDecoder::new(),
       
    89             buf_out: netbuf::Buf::new()
       
    90         }
       
    91     }
       
    92 
       
    /// Advances a pending TLS handshake by one step and stores the resulting
    /// socket state back into `self.socket`.
    ///
    /// Returns `Idle` both when the handshake completed and when it would
    /// block. Returns an `ErrorKind::Other` error when the handshake failed;
    /// the failed handshake is still stored so `socket.inner()` keeps working.
    #[cfg(feature = "tls-connections")]
    fn handshake_impl(&mut self, handshake: MidHandshakeSslStream<TcpStream>) -> io::Result<NetworkClientState> {
        let failed = match handshake.handshake() {
            Ok(established) => {
                self.socket = ClientSocket::SslStream(established);
                debug!("TLS handshake with {} ({}) completed", self.id, self.peer_addr);
                return Ok(NetworkClientState::Idle);
            }
            Err(HandshakeError::WouldBlock(pending)) => {
                // Not finished yet: park the handshake until the next event.
                self.socket = ClientSocket::SslHandshake(Some(pending));
                return Ok(NetworkClientState::Idle);
            }
            Err(HandshakeError::SetupFailure(_)) => unreachable!(),
            Err(HandshakeError::Failure(failed)) => failed,
        };

        self.socket = ClientSocket::SslHandshake(Some(failed));
        debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr);
        Err(Error::new(ErrorKind::Other, "Connection failure"))
    }
       
   113 
       
   114     fn read_impl<R: Read>(decoder: &mut ProtocolDecoder, source: &mut R,
       
   115                           id: ClientId, addr: &SocketAddr) -> NetworkResult<Vec<HWProtocolMessage>> {
       
   116         let mut bytes_read = 0;
       
   117         let result = loop {
       
   118             match decoder.read_from(source) {
       
   119                 Ok(bytes) => {
       
   120                     debug!("Client {}: read {} bytes", id, bytes);
       
   121                     bytes_read += bytes;
       
   122                     if bytes == 0 {
       
   123                         let result = if bytes_read == 0 {
       
   124                             info!("EOF for client {} ({})", id, addr);
       
   125                             (Vec::new(), NetworkClientState::Closed)
       
   126                         } else {
       
   127                             (decoder.extract_messages(), NetworkClientState::NeedsRead)
       
   128                         };
       
   129                         break Ok(result);
       
   130                     }
       
   131                     else if bytes_read >= MAX_BYTES_PER_READ {
       
   132                         break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead))
       
   133                     }
       
   134                 }
       
   135                 Err(ref error) if error.kind() == ErrorKind::WouldBlock => {
       
   136                     let messages =  if bytes_read == 0 {
       
   137                         Vec::new()
       
   138                     } else {
       
   139                         decoder.extract_messages()
       
   140                     };
       
   141                     break Ok((messages, NetworkClientState::Idle));
       
   142                 }
       
   143                 Err(error) =>
       
   144                     break Err(error)
       
   145             }
       
   146         };
       
   147         decoder.sweep();
       
   148         result
       
   149     }
       
   150 
       
   151     pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> {
       
   152         #[cfg(not(feature = "tls-connections"))]
       
   153         match self.socket {
       
   154             ClientSocket::Plain(ref mut stream) =>
       
   155                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr),
       
   156         }
       
   157 
       
   158         #[cfg(feature = "tls-connections")]
       
   159         match self.socket {
       
   160             ClientSocket::SslHandshake(ref mut handshake_opt) => {
       
   161                 let handshake = std::mem::replace(handshake_opt, None).unwrap();
       
   162                 Ok((Vec::new(), self.handshake_impl(handshake)?))
       
   163             },
       
   164             ClientSocket::SslStream(ref mut stream) =>
       
   165                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
       
   166         }
       
   167     }
       
   168 
       
   169     fn write_impl<W: Write>(buf_out: &mut netbuf::Buf, destination: &mut W) -> NetworkResult<()> {
       
   170         let result = loop {
       
   171             match buf_out.write_to(destination) {
       
   172                 Ok(bytes) if buf_out.is_empty() || bytes == 0 =>
       
   173                     break Ok(((), NetworkClientState::Idle)),
       
   174                 Ok(_) => (),
       
   175                 Err(ref error) if error.kind() == ErrorKind::Interrupted
       
   176                     || error.kind() == ErrorKind::WouldBlock => {
       
   177                     break Ok(((), NetworkClientState::NeedsWrite));
       
   178                 },
       
   179                 Err(error) =>
       
   180                     break Err(error)
       
   181             }
       
   182         };
       
   183         result
       
   184     }
       
   185 
       
   186     pub fn write(&mut self) -> NetworkResult<()> {
       
   187         let result = {
       
   188             #[cfg(not(feature = "tls-connections"))]
       
   189             match self.socket {
       
   190                 ClientSocket::Plain(ref mut stream) =>
       
   191                     NetworkClient::write_impl(&mut self.buf_out, stream)
       
   192             }
       
   193 
       
   194             #[cfg(feature = "tls-connections")] {
       
   195                 match self.socket {
       
   196                     ClientSocket::SslHandshake(ref mut handshake_opt) => {
       
   197                         let handshake = std::mem::replace(handshake_opt, None).unwrap();
       
   198                         Ok(((), self.handshake_impl(handshake)?))
       
   199                     }
       
   200                     ClientSocket::SslStream(ref mut stream) =>
       
   201                         NetworkClient::write_impl(&mut self.buf_out, stream)
       
   202                 }
       
   203             }
       
   204         };
       
   205 
       
   206         self.socket.inner().flush()?;
       
   207         result
       
   208     }
       
   209 
       
    /// Appends raw bytes to this client's outgoing buffer. Nothing reaches the
    /// socket until `write` is called.
    ///
    /// # Panics
    /// Panics if appending to the buffer reports an error.
    pub fn send_raw_msg(&mut self, msg: &[u8]) {
        self.buf_out.write_all(msg).unwrap();
    }
       
   213 
       
   214     pub fn send_string(&mut self, msg: &str) {
       
   215         self.send_raw_msg(&msg.as_bytes());
       
   216     }
       
   217 
       
   218     pub fn send_msg(&mut self, msg: &HWServerMessage) {
       
   219         self.send_string(&msg.to_raw_protocol());
       
   220     }
       
   221 }
       
   222 
       
/// TLS configuration shared by all accepted connections.
#[cfg(feature = "tls-connections")]
struct ServerSsl {
    /// OpenSSL context from which a per-connection `Ssl` object is created.
    context: SslContext
}
       
   227 
       
/// Glue between the listening socket, per-client network state and the game
/// server logic.
pub struct NetworkLayer {
    /// Socket on which new connections are accepted.
    listener: TcpListener,
    /// Server logic that consumes client messages and produces replies.
    server: HWServer,
    /// Network-side client records, looked up by client id.
    clients: Slab<NetworkClient>,
    /// Operations that must be retried for a client; drained by `on_idle`.
    pending: HashSet<(ClientId, NetworkClientState)>,
    /// Scratch buffer `on_idle` reuses while processing `pending`.
    pending_cache: Vec<(ClientId, NetworkClientState)>,
    /// TLS context, present only with the `tls-connections` feature.
    #[cfg(feature = "tls-connections")]
    ssl: ServerSsl
}
       
   237 
       
   238 impl NetworkLayer {
       
   239     pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer {
       
   240         let server = HWServer::new(clients_limit, rooms_limit, Box::new(FileServerIO::new()));
       
   241         let clients = Slab::with_capacity(clients_limit);
       
   242         let pending = HashSet::with_capacity(2 * clients_limit);
       
   243         let pending_cache = Vec::with_capacity(2 * clients_limit);
       
   244 
       
   245         NetworkLayer {
       
   246             listener, server, clients, pending, pending_cache,
       
   247             #[cfg(feature = "tls-connections")]
       
   248             ssl: NetworkLayer::create_ssl_context()
       
   249         }
       
   250     }
       
   251 
       
   252     #[cfg(feature = "tls-connections")]
       
   253     fn create_ssl_context() -> ServerSsl {
       
   254         let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap();
       
   255         builder.set_verify(SslVerifyMode::NONE);
       
   256         builder.set_read_ahead(true);
       
   257         builder.set_certificate_file("ssl/cert.pem", SslFiletype::PEM).unwrap();
       
   258         builder.set_private_key_file("ssl/key.pem", SslFiletype::PEM).unwrap();
       
   259         builder.set_options(SslOptions::NO_COMPRESSION);
       
   260         builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap();
       
   261         ServerSsl { context: builder.build() }
       
   262     }
       
   263 
       
   264     pub fn register_server(&self, poll: &Poll) -> io::Result<()> {
       
   265         poll.register(&self.listener, utils::SERVER, Ready::readable(),
       
   266                       PollOpt::edge())
       
   267     }
       
   268 
       
   269     fn deregister_client(&mut self, poll: &Poll, id: ClientId) {
       
   270         let mut client_exists = false;
       
   271         if let Some(ref client) = self.clients.get(id) {
       
   272             poll.deregister(client.socket.inner())
       
   273                 .expect("could not deregister socket");
       
   274             info!("client {} ({}) removed", client.id, client.peer_addr);
       
   275             client_exists = true;
       
   276         }
       
   277         if client_exists {
       
   278             self.clients.remove(id);
       
   279         }
       
   280     }
       
   281 
       
   282     fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: ClientSocket, addr: SocketAddr) {
       
   283         poll.register(client_socket.inner(), Token(id),
       
   284                       Ready::readable() | Ready::writable(),
       
   285                       PollOpt::edge())
       
   286             .expect("could not register socket with event loop");
       
   287 
       
   288         let entry = self.clients.vacant_entry();
       
   289         let client = NetworkClient::new(id, client_socket, addr);
       
   290         info!("client {} ({}) added", client.id, client.peer_addr);
       
   291         entry.insert(client);
       
   292     }
       
   293 
       
   294     fn flush_server_messages(&mut self) {
       
   295         debug!("{} pending server messages", self.server.output.len());
       
   296         for (clients, message) in self.server.output.drain(..) {
       
   297             debug!("Message {:?} to {:?}", message, clients);
       
   298             let msg_string = message.to_raw_protocol();
       
   299             for client_id in clients {
       
   300                 if let Some(client) = self.clients.get_mut(client_id) {
       
   301                     client.send_string(&msg_string);
       
   302                     self.pending.insert((client_id, NetworkClientState::NeedsWrite));
       
   303                 }
       
   304             }
       
   305         }
       
   306     }
       
   307 
       
   308     fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
       
   309         #[cfg(not(feature = "tls-connections"))] {
       
   310             Ok(ClientSocket::Plain(socket))
       
   311         }
       
   312 
       
   313         #[cfg(feature = "tls-connections")] {
       
   314             let ssl = Ssl::new(&self.ssl.context).unwrap();
       
   315             let mut builder = SslStreamBuilder::new(ssl, socket);
       
   316             builder.set_accept_state();
       
   317             match builder.handshake() {
       
   318                 Ok(stream) =>
       
   319                     Ok(ClientSocket::SslStream(stream)),
       
   320                 Err(HandshakeError::WouldBlock(stream)) =>
       
   321                     Ok(ClientSocket::SslHandshake(Some(stream))),
       
   322                 Err(e) => {
       
   323                     debug!("OpenSSL handshake failed: {}", e);
       
   324                     Err(Error::new(ErrorKind::Other, "Connection failure"))
       
   325                 }
       
   326             }
       
   327         }
       
   328     }
       
   329 
       
   330     pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> {
       
   331         let (client_socket, addr) = self.listener.accept()?;
       
   332         info!("Connected: {}", addr);
       
   333 
       
   334         let client_id = self.server.add_client();
       
   335         self.register_client(poll, client_id, self.create_client_socket(client_socket)?, addr);
       
   336         self.flush_server_messages();
       
   337 
       
   338         Ok(())
       
   339     }
       
   340 
       
   341     fn operation_failed(&mut self, poll: &Poll, client_id: ClientId, error: &Error, msg: &str) -> io::Result<()> {
       
   342         let addr = if let Some(ref mut client) = self.clients.get_mut(client_id) {
       
   343             client.peer_addr
       
   344         } else {
       
   345             SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0)
       
   346         };
       
   347         debug!("{}({}): {}", msg, addr, error);
       
   348         self.client_error(poll, client_id)
       
   349     }
       
   350 
       
   351     pub fn client_readable(&mut self, poll: &Poll,
       
   352                            client_id: ClientId) -> io::Result<()> {
       
   353         let messages =
       
   354             if let Some(ref mut client) = self.clients.get_mut(client_id) {
       
   355                 client.read()
       
   356             } else {
       
   357                 warn!("invalid readable client: {}", client_id);
       
   358                 Ok((Vec::new(), NetworkClientState::Idle))
       
   359             };
       
   360 
       
   361         match messages {
       
   362             Ok((messages, state)) => {
       
   363                 for message in messages {
       
   364                     self.server.handle_msg(client_id, message);
       
   365                 }
       
   366                 match state {
       
   367                     NetworkClientState::NeedsRead => {
       
   368                         self.pending.insert((client_id, state));
       
   369                     },
       
   370                     NetworkClientState::Closed =>
       
   371                         self.client_error(&poll, client_id)?,
       
   372                     _ => {}
       
   373                 };
       
   374             }
       
   375             Err(e) => self.operation_failed(
       
   376                 poll, client_id, &e,
       
   377                 "Error while reading from client socket")?
       
   378         }
       
   379 
       
   380         self.flush_server_messages();
       
   381 
       
   382         if !self.server.removed_clients.is_empty() {
       
   383             let ids: Vec<_> = self.server.removed_clients.drain(..).collect();
       
   384             for client_id in ids {
       
   385                 self.deregister_client(poll, client_id);
       
   386             }
       
   387         }
       
   388 
       
   389         Ok(())
       
   390     }
       
   391 
       
   392     pub fn client_writable(&mut self, poll: &Poll,
       
   393                            client_id: ClientId) -> io::Result<()> {
       
   394         let result =
       
   395             if let Some(ref mut client) = self.clients.get_mut(client_id) {
       
   396                 client.write()
       
   397             } else {
       
   398                 warn!("invalid writable client: {}", client_id);
       
   399                 Ok(((), NetworkClientState::Idle))
       
   400             };
       
   401 
       
   402         match result {
       
   403             Ok(((), state)) if state == NetworkClientState::NeedsWrite => {
       
   404                 self.pending.insert((client_id, state));
       
   405             },
       
   406             Ok(_) => {}
       
   407             Err(e) => self.operation_failed(
       
   408                 poll, client_id, &e,
       
   409                 "Error while writing to client socket")?
       
   410         }
       
   411 
       
   412         Ok(())
       
   413     }
       
   414 
       
    /// Drops a client: removes its network-side record (if any) and then
    /// notifies the server through `client_lost`. Always returns `Ok(())`.
    pub fn client_error(&mut self, poll: &Poll,
                        client_id: ClientId) -> io::Result<()> {
        self.deregister_client(poll, client_id);
        self.server.client_lost(client_id);

        Ok(())
    }
       
   422 
       
    /// Returns `true` when at least one client operation is queued for `on_idle`.
    pub fn has_pending_operations(&self) -> bool {
        !self.pending.is_empty()
    }
       
   426 
       
   427     pub fn on_idle(&mut self, poll: &Poll) -> io::Result<()> {
       
   428         if self.has_pending_operations() {
       
   429             let mut cache = replace(&mut self.pending_cache, Vec::new());
       
   430             cache.extend(self.pending.drain());
       
   431             for (id, state) in cache.drain(..) {
       
   432                 match state {
       
   433                     NetworkClientState::NeedsRead =>
       
   434                         self.client_readable(poll, id)?,
       
   435                     NetworkClientState::NeedsWrite =>
       
   436                         self.client_writable(poll, id)?,
       
   437                     _ => {}
       
   438                 }
       
   439             }
       
   440             swap(&mut cache, &mut self.pending_cache);
       
   441         }
       
   442         Ok(())
       
   443     }
       
   444 }