rust/hedgewars-server/src/server/network.rs
changeset 14457 98ef2913ec73
parent 14415 06672690d71b
child 14671 455865ccd36c
equal deleted inserted replaced
14456:a077aac9df01 14457:98ef2913ec73
     1 extern crate slab;
     1 extern crate slab;
     2 
     2 
     3 use std::{
     3 use std::{
     4     io, io::{Error, ErrorKind, Read, Write},
       
     5     net::{SocketAddr, IpAddr, Ipv4Addr},
       
     6     collections::HashSet,
     4     collections::HashSet,
     7     mem::{swap, replace}
     5     io,
       
     6     io::{Error, ErrorKind, Read, Write},
       
     7     mem::{replace, swap},
       
     8     net::{IpAddr, Ipv4Addr, SocketAddr},
     8 };
     9 };
     9 
    10 
       
    11 use log::*;
    10 use mio::{
    12 use mio::{
    11     net::{TcpStream, TcpListener},
    13     net::{TcpListener, TcpStream},
    12     Poll, PollOpt, Ready, Token
    14     Poll, PollOpt, Ready, Token,
    13 };
    15 };
    14 use netbuf;
    16 use netbuf;
    15 use slab::Slab;
    17 use slab::Slab;
    16 use log::*;
    18 
    17 
    19 use super::{core::HWServer, coretypes::ClientId, io::FileServerIO};
    18 use crate::{
    20 use crate::{
       
    21     protocol::{messages::*, ProtocolDecoder},
    19     utils,
    22     utils,
    20     protocol::{ProtocolDecoder, messages::*}
       
    21 };
       
    22 use super::{
       
    23     io::FileServerIO,
       
    24     core::{HWServer},
       
    25     coretypes::ClientId
       
    26 };
    23 };
    27 #[cfg(feature = "tls-connections")]
    24 #[cfg(feature = "tls-connections")]
    28 use openssl::{
    25 use openssl::{
       
    26     error::ErrorStack,
    29     ssl::{
    27     ssl::{
    30         SslMethod, SslContext, Ssl, SslContextBuilder,
    28         HandshakeError, MidHandshakeSslStream, Ssl, SslContext, SslContextBuilder, SslFiletype,
    31         SslVerifyMode, SslFiletype, SslOptions,
    29         SslMethod, SslOptions, SslStream, SslStreamBuilder, SslVerifyMode,
    32         SslStreamBuilder, HandshakeError, MidHandshakeSslStream, SslStream
       
    33     },
    30     },
    34     error::ErrorStack
       
    35 };
    31 };
    36 
    32 
    37 const MAX_BYTES_PER_READ: usize = 2048;
    33 const MAX_BYTES_PER_READ: usize = 2048;
    38 
    34 
    39 #[derive(Hash, Eq, PartialEq, Copy, Clone)]
    35 #[derive(Hash, Eq, PartialEq, Copy, Clone)]
    46 
    42 
    47 type NetworkResult<T> = io::Result<(T, NetworkClientState)>;
    43 type NetworkResult<T> = io::Result<(T, NetworkClientState)>;
    48 
    44 
    49 #[cfg(not(feature = "tls-connections"))]
    45 #[cfg(not(feature = "tls-connections"))]
    50 pub enum ClientSocket {
    46 pub enum ClientSocket {
    51     Plain(TcpStream)
    47     Plain(TcpStream),
    52 }
    48 }
    53 
    49 
    54 #[cfg(feature = "tls-connections")]
    50 #[cfg(feature = "tls-connections")]
    55 pub enum ClientSocket {
    51 pub enum ClientSocket {
    56     SslHandshake(Option<MidHandshakeSslStream<TcpStream>>),
    52     SslHandshake(Option<MidHandshakeSslStream<TcpStream>>),
    57     SslStream(SslStream<TcpStream>)
    53     SslStream(SslStream<TcpStream>),
    58 }
    54 }
    59 
    55 
    60 impl ClientSocket {
    56 impl ClientSocket {
    61     fn inner(&self) -> &TcpStream {
    57     fn inner(&self) -> &TcpStream {
    62         #[cfg(not(feature = "tls-connections"))]
    58         #[cfg(not(feature = "tls-connections"))]
    66 
    62 
    67         #[cfg(feature = "tls-connections")]
    63         #[cfg(feature = "tls-connections")]
    68         match self {
    64         match self {
    69             ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(),
    65             ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(),
    70             ClientSocket::SslHandshake(None) => unreachable!(),
    66             ClientSocket::SslHandshake(None) => unreachable!(),
    71             ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref()
    67             ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref(),
    72         }
    68         }
    73     }
    69     }
    74 }
    70 }
    75 
    71 
    76 pub struct NetworkClient {
    72 pub struct NetworkClient {
    77     id: ClientId,
    73     id: ClientId,
    78     socket: ClientSocket,
    74     socket: ClientSocket,
    79     peer_addr: SocketAddr,
    75     peer_addr: SocketAddr,
    80     decoder: ProtocolDecoder,
    76     decoder: ProtocolDecoder,
    81     buf_out: netbuf::Buf
    77     buf_out: netbuf::Buf,
    82 }
    78 }
    83 
    79 
    84 impl NetworkClient {
    80 impl NetworkClient {
    85     pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient {
    81     pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient {
    86         NetworkClient {
    82         NetworkClient {
    87             id, socket, peer_addr,
    83             id,
       
    84             socket,
       
    85             peer_addr,
    88             decoder: ProtocolDecoder::new(),
    86             decoder: ProtocolDecoder::new(),
    89             buf_out: netbuf::Buf::new()
    87             buf_out: netbuf::Buf::new(),
    90         }
    88         }
    91     }
    89     }
    92 
    90 
    93     #[cfg(feature = "tls-connections")]
    91     #[cfg(feature = "tls-connections")]
    94     fn handshake_impl(&mut self, handshake: MidHandshakeSslStream<TcpStream>) -> io::Result<NetworkClientState> {
    92     fn handshake_impl(
       
    93         &mut self,
       
    94         handshake: MidHandshakeSslStream<TcpStream>,
       
    95     ) -> io::Result<NetworkClientState> {
    95         match handshake.handshake() {
    96         match handshake.handshake() {
    96             Ok(stream) => {
    97             Ok(stream) => {
    97                 self.socket = ClientSocket::SslStream(stream);
    98                 self.socket = ClientSocket::SslStream(stream);
    98                 debug!("TLS handshake with {} ({}) completed", self.id, self.peer_addr);
    99                 debug!(
       
   100                     "TLS handshake with {} ({}) completed",
       
   101                     self.id, self.peer_addr
       
   102                 );
    99                 Ok(NetworkClientState::Idle)
   103                 Ok(NetworkClientState::Idle)
   100             }
   104             }
   101             Err(HandshakeError::WouldBlock(new_handshake)) => {
   105             Err(HandshakeError::WouldBlock(new_handshake)) => {
   102                 self.socket = ClientSocket::SslHandshake(Some(new_handshake));
   106                 self.socket = ClientSocket::SslHandshake(Some(new_handshake));
   103                 Ok(NetworkClientState::Idle)
   107                 Ok(NetworkClientState::Idle)
   105             Err(HandshakeError::Failure(new_handshake)) => {
   109             Err(HandshakeError::Failure(new_handshake)) => {
   106                 self.socket = ClientSocket::SslHandshake(Some(new_handshake));
   110                 self.socket = ClientSocket::SslHandshake(Some(new_handshake));
   107                 debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr);
   111                 debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr);
   108                 Err(Error::new(ErrorKind::Other, "Connection failure"))
   112                 Err(Error::new(ErrorKind::Other, "Connection failure"))
   109             }
   113             }
   110             Err(HandshakeError::SetupFailure(_)) => unreachable!()
   114             Err(HandshakeError::SetupFailure(_)) => unreachable!(),
   111         }
   115         }
   112     }
   116     }
   113 
   117 
   114     fn read_impl<R: Read>(decoder: &mut ProtocolDecoder, source: &mut R,
   118     fn read_impl<R: Read>(
   115                           id: ClientId, addr: &SocketAddr) -> NetworkResult<Vec<HWProtocolMessage>> {
   119         decoder: &mut ProtocolDecoder,
       
   120         source: &mut R,
       
   121         id: ClientId,
       
   122         addr: &SocketAddr,
       
   123     ) -> NetworkResult<Vec<HWProtocolMessage>> {
   116         let mut bytes_read = 0;
   124         let mut bytes_read = 0;
   117         let result = loop {
   125         let result = loop {
   118             match decoder.read_from(source) {
   126             match decoder.read_from(source) {
   119                 Ok(bytes) => {
   127                 Ok(bytes) => {
   120                     debug!("Client {}: read {} bytes", id, bytes);
   128                     debug!("Client {}: read {} bytes", id, bytes);
   125                             (Vec::new(), NetworkClientState::Closed)
   133                             (Vec::new(), NetworkClientState::Closed)
   126                         } else {
   134                         } else {
   127                             (decoder.extract_messages(), NetworkClientState::NeedsRead)
   135                             (decoder.extract_messages(), NetworkClientState::NeedsRead)
   128                         };
   136                         };
   129                         break Ok(result);
   137                         break Ok(result);
       
   138                     } else if bytes_read >= MAX_BYTES_PER_READ {
       
   139                         break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead));
   130                     }
   140                     }
   131                     else if bytes_read >= MAX_BYTES_PER_READ {
       
   132                         break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead))
       
   133                     }
       
   134                 }
   141                 }
   135                 Err(ref error) if error.kind() == ErrorKind::WouldBlock => {
   142                 Err(ref error) if error.kind() == ErrorKind::WouldBlock => {
   136                     let messages =  if bytes_read == 0 {
   143                     let messages = if bytes_read == 0 {
   137                         Vec::new()
   144                         Vec::new()
   138                     } else {
   145                     } else {
   139                         decoder.extract_messages()
   146                         decoder.extract_messages()
   140                     };
   147                     };
   141                     break Ok((messages, NetworkClientState::Idle));
   148                     break Ok((messages, NetworkClientState::Idle));
   142                 }
   149                 }
   143                 Err(error) =>
   150                 Err(error) => break Err(error),
   144                     break Err(error)
       
   145             }
   151             }
   146         };
   152         };
   147         decoder.sweep();
   153         decoder.sweep();
   148         result
   154         result
   149     }
   155     }
   150 
   156 
   151     pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> {
   157     pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> {
   152         #[cfg(not(feature = "tls-connections"))]
   158         #[cfg(not(feature = "tls-connections"))]
   153         match self.socket {
   159         match self.socket {
   154             ClientSocket::Plain(ref mut stream) =>
   160             ClientSocket::Plain(ref mut stream) => {
   155                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr),
   161                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
       
   162             }
   156         }
   163         }
   157 
   164 
   158         #[cfg(feature = "tls-connections")]
   165         #[cfg(feature = "tls-connections")]
   159         match self.socket {
   166         match self.socket {
   160             ClientSocket::SslHandshake(ref mut handshake_opt) => {
   167             ClientSocket::SslHandshake(ref mut handshake_opt) => {
   161                 let handshake = std::mem::replace(handshake_opt, None).unwrap();
   168                 let handshake = std::mem::replace(handshake_opt, None).unwrap();
   162                 Ok((Vec::new(), self.handshake_impl(handshake)?))
   169                 Ok((Vec::new(), self.handshake_impl(handshake)?))
   163             },
   170             }
   164             ClientSocket::SslStream(ref mut stream) =>
   171             ClientSocket::SslStream(ref mut stream) => {
   165                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
   172                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
       
   173             }
   166         }
   174         }
   167     }
   175     }
   168 
   176 
   169     fn write_impl<W: Write>(buf_out: &mut netbuf::Buf, destination: &mut W) -> NetworkResult<()> {
   177     fn write_impl<W: Write>(buf_out: &mut netbuf::Buf, destination: &mut W) -> NetworkResult<()> {
   170         let result = loop {
   178         let result = loop {
   171             match buf_out.write_to(destination) {
   179             match buf_out.write_to(destination) {
   172                 Ok(bytes) if buf_out.is_empty() || bytes == 0 =>
   180                 Ok(bytes) if buf_out.is_empty() || bytes == 0 => {
   173                     break Ok(((), NetworkClientState::Idle)),
   181                     break Ok(((), NetworkClientState::Idle))
       
   182                 }
   174                 Ok(_) => (),
   183                 Ok(_) => (),
   175                 Err(ref error) if error.kind() == ErrorKind::Interrupted
   184                 Err(ref error)
   176                     || error.kind() == ErrorKind::WouldBlock => {
   185                     if error.kind() == ErrorKind::Interrupted
       
   186                         || error.kind() == ErrorKind::WouldBlock =>
       
   187                 {
   177                     break Ok(((), NetworkClientState::NeedsWrite));
   188                     break Ok(((), NetworkClientState::NeedsWrite));
   178                 },
   189                 }
   179                 Err(error) =>
   190                 Err(error) => break Err(error),
   180                     break Err(error)
       
   181             }
   191             }
   182         };
   192         };
   183         result
   193         result
   184     }
   194     }
   185 
   195 
   186     pub fn write(&mut self) -> NetworkResult<()> {
   196     pub fn write(&mut self) -> NetworkResult<()> {
   187         let result = {
   197         let result = {
   188             #[cfg(not(feature = "tls-connections"))]
   198             #[cfg(not(feature = "tls-connections"))]
   189             match self.socket {
   199             match self.socket {
   190                 ClientSocket::Plain(ref mut stream) =>
   200                 ClientSocket::Plain(ref mut stream) => {
   191                     NetworkClient::write_impl(&mut self.buf_out, stream)
   201                     NetworkClient::write_impl(&mut self.buf_out, stream)
   192             }
   202                 }
   193 
   203             }
   194             #[cfg(feature = "tls-connections")] {
   204 
       
   205             #[cfg(feature = "tls-connections")]
       
   206             {
   195                 match self.socket {
   207                 match self.socket {
   196                     ClientSocket::SslHandshake(ref mut handshake_opt) => {
   208                     ClientSocket::SslHandshake(ref mut handshake_opt) => {
   197                         let handshake = std::mem::replace(handshake_opt, None).unwrap();
   209                         let handshake = std::mem::replace(handshake_opt, None).unwrap();
   198                         Ok(((), self.handshake_impl(handshake)?))
   210                         Ok(((), self.handshake_impl(handshake)?))
   199                     }
   211                     }
   200                     ClientSocket::SslStream(ref mut stream) =>
   212                     ClientSocket::SslStream(ref mut stream) => {
   201                         NetworkClient::write_impl(&mut self.buf_out, stream)
   213                         NetworkClient::write_impl(&mut self.buf_out, stream)
       
   214                     }
   202                 }
   215                 }
   203             }
   216             }
   204         };
   217         };
   205 
   218 
   206         self.socket.inner().flush()?;
   219         self.socket.inner().flush()?;
   220     }
   233     }
   221 }
   234 }
   222 
   235 
   223 #[cfg(feature = "tls-connections")]
   236 #[cfg(feature = "tls-connections")]
   224 struct ServerSsl {
   237 struct ServerSsl {
   225     context: SslContext
   238     context: SslContext,
   226 }
   239 }
   227 
   240 
   228 pub struct NetworkLayer {
   241 pub struct NetworkLayer {
   229     listener: TcpListener,
   242     listener: TcpListener,
   230     server: HWServer,
   243     server: HWServer,
   231     clients: Slab<NetworkClient>,
   244     clients: Slab<NetworkClient>,
   232     pending: HashSet<(ClientId, NetworkClientState)>,
   245     pending: HashSet<(ClientId, NetworkClientState)>,
   233     pending_cache: Vec<(ClientId, NetworkClientState)>,
   246     pending_cache: Vec<(ClientId, NetworkClientState)>,
   234     #[cfg(feature = "tls-connections")]
   247     #[cfg(feature = "tls-connections")]
   235     ssl: ServerSsl
   248     ssl: ServerSsl,
   236 }
   249 }
   237 
   250 
   238 impl NetworkLayer {
   251 impl NetworkLayer {
   239     pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer {
   252     pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer {
   240         let server = HWServer::new(clients_limit, rooms_limit, Box::new(FileServerIO::new()));
   253         let server = HWServer::new(clients_limit, rooms_limit, Box::new(FileServerIO::new()));
   241         let clients = Slab::with_capacity(clients_limit);
   254         let clients = Slab::with_capacity(clients_limit);
   242         let pending = HashSet::with_capacity(2 * clients_limit);
   255         let pending = HashSet::with_capacity(2 * clients_limit);
   243         let pending_cache = Vec::with_capacity(2 * clients_limit);
   256         let pending_cache = Vec::with_capacity(2 * clients_limit);
   244 
   257 
   245         NetworkLayer {
   258         NetworkLayer {
   246             listener, server, clients, pending, pending_cache,
   259             listener,
       
   260             server,
       
   261             clients,
       
   262             pending,
       
   263             pending_cache,
   247             #[cfg(feature = "tls-connections")]
   264             #[cfg(feature = "tls-connections")]
   248             ssl: NetworkLayer::create_ssl_context()
   265             ssl: NetworkLayer::create_ssl_context(),
   249         }
   266         }
   250     }
   267     }
   251 
   268 
   252     #[cfg(feature = "tls-connections")]
   269     #[cfg(feature = "tls-connections")]
   253     fn create_ssl_context() -> ServerSsl {
   270     fn create_ssl_context() -> ServerSsl {
   254         let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap();
   271         let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap();
   255         builder.set_verify(SslVerifyMode::NONE);
   272         builder.set_verify(SslVerifyMode::NONE);
   256         builder.set_read_ahead(true);
   273         builder.set_read_ahead(true);
   257         builder.set_certificate_file("ssl/cert.pem", SslFiletype::PEM).unwrap();
   274         builder
   258         builder.set_private_key_file("ssl/key.pem", SslFiletype::PEM).unwrap();
   275             .set_certificate_file("ssl/cert.pem", SslFiletype::PEM)
       
   276             .unwrap();
       
   277         builder
       
   278             .set_private_key_file("ssl/key.pem", SslFiletype::PEM)
       
   279             .unwrap();
   259         builder.set_options(SslOptions::NO_COMPRESSION);
   280         builder.set_options(SslOptions::NO_COMPRESSION);
   260         builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap();
   281         builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap();
   261         ServerSsl { context: builder.build() }
   282         ServerSsl {
       
   283             context: builder.build(),
       
   284         }
   262     }
   285     }
   263 
   286 
   264     pub fn register_server(&self, poll: &Poll) -> io::Result<()> {
   287     pub fn register_server(&self, poll: &Poll) -> io::Result<()> {
   265         poll.register(&self.listener, utils::SERVER, Ready::readable(),
   288         poll.register(
   266                       PollOpt::edge())
   289             &self.listener,
       
   290             utils::SERVER,
       
   291             Ready::readable(),
       
   292             PollOpt::edge(),
       
   293         )
   267     }
   294     }
   268 
   295 
   269     fn deregister_client(&mut self, poll: &Poll, id: ClientId) {
   296     fn deregister_client(&mut self, poll: &Poll, id: ClientId) {
   270         let mut client_exists = false;
   297         let mut client_exists = false;
   271         if let Some(ref client) = self.clients.get(id) {
   298         if let Some(ref client) = self.clients.get(id) {
   277         if client_exists {
   304         if client_exists {
   278             self.clients.remove(id);
   305             self.clients.remove(id);
   279         }
   306         }
   280     }
   307     }
   281 
   308 
   282     fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: ClientSocket, addr: SocketAddr) {
   309     fn register_client(
   283         poll.register(client_socket.inner(), Token(id),
   310         &mut self,
   284                       Ready::readable() | Ready::writable(),
   311         poll: &Poll,
   285                       PollOpt::edge())
   312         id: ClientId,
   286             .expect("could not register socket with event loop");
   313         client_socket: ClientSocket,
       
   314         addr: SocketAddr,
       
   315     ) {
       
   316         poll.register(
       
   317             client_socket.inner(),
       
   318             Token(id),
       
   319             Ready::readable() | Ready::writable(),
       
   320             PollOpt::edge(),
       
   321         )
       
   322         .expect("could not register socket with event loop");
   287 
   323 
   288         let entry = self.clients.vacant_entry();
   324         let entry = self.clients.vacant_entry();
   289         let client = NetworkClient::new(id, client_socket, addr);
   325         let client = NetworkClient::new(id, client_socket, addr);
   290         info!("client {} ({}) added", client.id, client.peer_addr);
   326         info!("client {} ({}) added", client.id, client.peer_addr);
   291         entry.insert(client);
   327         entry.insert(client);
   297             debug!("Message {:?} to {:?}", message, clients);
   333             debug!("Message {:?} to {:?}", message, clients);
   298             let msg_string = message.to_raw_protocol();
   334             let msg_string = message.to_raw_protocol();
   299             for client_id in clients {
   335             for client_id in clients {
   300                 if let Some(client) = self.clients.get_mut(client_id) {
   336                 if let Some(client) = self.clients.get_mut(client_id) {
   301                     client.send_string(&msg_string);
   337                     client.send_string(&msg_string);
   302                     self.pending.insert((client_id, NetworkClientState::NeedsWrite));
   338                     self.pending
       
   339                         .insert((client_id, NetworkClientState::NeedsWrite));
   303                 }
   340                 }
   304             }
   341             }
   305         }
   342         }
   306     }
   343     }
   307 
   344 
   308     fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
   345     fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
   309         #[cfg(not(feature = "tls-connections"))] {
   346         #[cfg(not(feature = "tls-connections"))]
       
   347         {
   310             Ok(ClientSocket::Plain(socket))
   348             Ok(ClientSocket::Plain(socket))
   311         }
   349         }
   312 
   350 
   313         #[cfg(feature = "tls-connections")] {
   351         #[cfg(feature = "tls-connections")]
       
   352         {
   314             let ssl = Ssl::new(&self.ssl.context).unwrap();
   353             let ssl = Ssl::new(&self.ssl.context).unwrap();
   315             let mut builder = SslStreamBuilder::new(ssl, socket);
   354             let mut builder = SslStreamBuilder::new(ssl, socket);
   316             builder.set_accept_state();
   355             builder.set_accept_state();
   317             match builder.handshake() {
   356             match builder.handshake() {
   318                 Ok(stream) =>
   357                 Ok(stream) => Ok(ClientSocket::SslStream(stream)),
   319                     Ok(ClientSocket::SslStream(stream)),
   358                 Err(HandshakeError::WouldBlock(stream)) => {
   320                 Err(HandshakeError::WouldBlock(stream)) =>
   359                     Ok(ClientSocket::SslHandshake(Some(stream)))
   321                     Ok(ClientSocket::SslHandshake(Some(stream))),
   360                 }
   322                 Err(e) => {
   361                 Err(e) => {
   323                     debug!("OpenSSL handshake failed: {}", e);
   362                     debug!("OpenSSL handshake failed: {}", e);
   324                     Err(Error::new(ErrorKind::Other, "Connection failure"))
   363                     Err(Error::new(ErrorKind::Other, "Connection failure"))
   325                 }
   364                 }
   326             }
   365             }
   330     pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> {
   369     pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> {
   331         let (client_socket, addr) = self.listener.accept()?;
   370         let (client_socket, addr) = self.listener.accept()?;
   332         info!("Connected: {}", addr);
   371         info!("Connected: {}", addr);
   333 
   372 
   334         let client_id = self.server.add_client();
   373         let client_id = self.server.add_client();
   335         self.register_client(poll, client_id, self.create_client_socket(client_socket)?, addr);
   374         self.register_client(
       
   375             poll,
       
   376             client_id,
       
   377             self.create_client_socket(client_socket)?,
       
   378             addr,
       
   379         );
   336         self.flush_server_messages();
   380         self.flush_server_messages();
   337 
   381 
   338         Ok(())
   382         Ok(())
   339     }
   383     }
   340 
   384 
   341     fn operation_failed(&mut self, poll: &Poll, client_id: ClientId, error: &Error, msg: &str) -> io::Result<()> {
   385     fn operation_failed(
       
   386         &mut self,
       
   387         poll: &Poll,
       
   388         client_id: ClientId,
       
   389         error: &Error,
       
   390         msg: &str,
       
   391     ) -> io::Result<()> {
   342         let addr = if let Some(ref mut client) = self.clients.get_mut(client_id) {
   392         let addr = if let Some(ref mut client) = self.clients.get_mut(client_id) {
   343             client.peer_addr
   393             client.peer_addr
   344         } else {
   394         } else {
   345             SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0)
   395             SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0)
   346         };
   396         };
   347         debug!("{}({}): {}", msg, addr, error);
   397         debug!("{}({}): {}", msg, addr, error);
   348         self.client_error(poll, client_id)
   398         self.client_error(poll, client_id)
   349     }
   399     }
   350 
   400 
   351     pub fn client_readable(&mut self, poll: &Poll,
   401     pub fn client_readable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
   352                            client_id: ClientId) -> io::Result<()> {
   402         let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) {
   353         let messages =
   403             client.read()
   354             if let Some(ref mut client) = self.clients.get_mut(client_id) {
   404         } else {
   355                 client.read()
   405             warn!("invalid readable client: {}", client_id);
   356             } else {
   406             Ok((Vec::new(), NetworkClientState::Idle))
   357                 warn!("invalid readable client: {}", client_id);
   407         };
   358                 Ok((Vec::new(), NetworkClientState::Idle))
       
   359             };
       
   360 
   408 
   361         match messages {
   409         match messages {
   362             Ok((messages, state)) => {
   410             Ok((messages, state)) => {
   363                 for message in messages {
   411                 for message in messages {
   364                     self.server.handle_msg(client_id, message);
   412                     self.server.handle_msg(client_id, message);
   365                 }
   413                 }
   366                 match state {
   414                 match state {
   367                     NetworkClientState::NeedsRead => {
   415                     NetworkClientState::NeedsRead => {
   368                         self.pending.insert((client_id, state));
   416                         self.pending.insert((client_id, state));
   369                     },
   417                     }
   370                     NetworkClientState::Closed =>
   418                     NetworkClientState::Closed => self.client_error(&poll, client_id)?,
   371                         self.client_error(&poll, client_id)?,
       
   372                     _ => {}
   419                     _ => {}
   373                 };
   420                 };
   374             }
   421             }
   375             Err(e) => self.operation_failed(
   422             Err(e) => self.operation_failed(
   376                 poll, client_id, &e,
   423                 poll,
   377                 "Error while reading from client socket")?
   424                 client_id,
       
   425                 &e,
       
   426                 "Error while reading from client socket",
       
   427             )?,
   378         }
   428         }
   379 
   429 
   380         self.flush_server_messages();
   430         self.flush_server_messages();
   381 
   431 
   382         if !self.server.removed_clients.is_empty() {
   432         if !self.server.removed_clients.is_empty() {
   387         }
   437         }
   388 
   438 
   389         Ok(())
   439         Ok(())
   390     }
   440     }
   391 
   441 
   392     pub fn client_writable(&mut self, poll: &Poll,
   442     pub fn client_writable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
   393                            client_id: ClientId) -> io::Result<()> {
   443         let result = if let Some(ref mut client) = self.clients.get_mut(client_id) {
   394         let result =
   444             client.write()
   395             if let Some(ref mut client) = self.clients.get_mut(client_id) {
   445         } else {
   396                 client.write()
   446             warn!("invalid writable client: {}", client_id);
   397             } else {
   447             Ok(((), NetworkClientState::Idle))
   398                 warn!("invalid writable client: {}", client_id);
   448         };
   399                 Ok(((), NetworkClientState::Idle))
       
   400             };
       
   401 
   449 
   402         match result {
   450         match result {
   403             Ok(((), state)) if state == NetworkClientState::NeedsWrite => {
   451             Ok(((), state)) if state == NetworkClientState::NeedsWrite => {
   404                 self.pending.insert((client_id, state));
   452                 self.pending.insert((client_id, state));
   405             },
   453             }
   406             Ok(_) => {}
   454             Ok(_) => {}
   407             Err(e) => self.operation_failed(
   455             Err(e) => {
   408                 poll, client_id, &e,
   456                 self.operation_failed(poll, client_id, &e, "Error while writing to client socket")?
   409                 "Error while writing to client socket")?
   457             }
   410         }
   458         }
   411 
   459 
   412         Ok(())
   460         Ok(())
   413     }
   461     }
   414 
   462 
   415     pub fn client_error(&mut self, poll: &Poll,
   463     pub fn client_error(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
   416                         client_id: ClientId) -> io::Result<()> {
       
   417         self.deregister_client(poll, client_id);
   464         self.deregister_client(poll, client_id);
   418         self.server.client_lost(client_id);
   465         self.server.client_lost(client_id);
   419 
   466 
   420         Ok(())
   467         Ok(())
   421     }
   468     }
   428         if self.has_pending_operations() {
   475         if self.has_pending_operations() {
   429             let mut cache = replace(&mut self.pending_cache, Vec::new());
   476             let mut cache = replace(&mut self.pending_cache, Vec::new());
   430             cache.extend(self.pending.drain());
   477             cache.extend(self.pending.drain());
   431             for (id, state) in cache.drain(..) {
   478             for (id, state) in cache.drain(..) {
   432                 match state {
   479                 match state {
   433                     NetworkClientState::NeedsRead =>
   480                     NetworkClientState::NeedsRead => self.client_readable(poll, id)?,
   434                         self.client_readable(poll, id)?,
   481                     NetworkClientState::NeedsWrite => self.client_writable(poll, id)?,
   435                     NetworkClientState::NeedsWrite =>
       
   436                         self.client_writable(poll, id)?,
       
   437                     _ => {}
   482                     _ => {}
   438                 }
   483                 }
   439             }
   484             }
   440             swap(&mut cache, &mut self.pending_cache);
   485             swap(&mut cache, &mut self.pending_cache);
   441         }
   486         }