gameServer2/src/server/network.rs
changeset 13799 c8fd12db6215
parent 13666 09f4a30e50cc
child 13802 24fe56d3c6a2
equal deleted inserted replaced
13798:4664da990556 13799:c8fd12db6215
     1 extern crate slab;
     1 extern crate slab;
     2 
     2 
     3 use std::{
     3 use std::{
     4     io, io::{Error, ErrorKind, Write},
     4     io, io::{Error, ErrorKind, Read, Write},
     5     net::{SocketAddr, IpAddr, Ipv4Addr},
     5     net::{SocketAddr, IpAddr, Ipv4Addr},
     6     collections::HashSet,
     6     collections::HashSet,
     7     mem::{swap, replace}
     7     mem::{swap, replace}
     8 };
     8 };
     9 
     9 
    20 };
    20 };
    21 use super::{
    21 use super::{
    22     server::{HWServer},
    22     server::{HWServer},
    23     coretypes::ClientId
    23     coretypes::ClientId
    24 };
    24 };
       
    25 #[cfg(feature = "tls-connections")]
       
    26 use openssl::{
       
    27     ssl::{
       
    28         SslMethod, SslContext, Ssl, SslContextBuilder,
       
    29         SslVerifyMode, SslFiletype, SslOptions,
       
    30         SslStreamBuilder, HandshakeError, MidHandshakeSslStream, SslStream
       
    31     },
       
    32     error::ErrorStack
       
    33 };
    25 
    34 
    26 const MAX_BYTES_PER_READ: usize = 2048;
    35 const MAX_BYTES_PER_READ: usize = 2048;
    27 
    36 
    28 #[derive(Hash, Eq, PartialEq, Copy, Clone)]
    37 #[derive(Hash, Eq, PartialEq, Copy, Clone)]
    29 pub enum NetworkClientState {
    38 pub enum NetworkClientState {
    33     Closed,
    42     Closed,
    34 }
    43 }
    35 
    44 
    36 type NetworkResult<T> = io::Result<(T, NetworkClientState)>;
    45 type NetworkResult<T> = io::Result<(T, NetworkClientState)>;
    37 
    46 
       
    47 #[cfg(not(feature = "tls-connections"))]
       
    48 pub enum ClientSocket {
       
    49     Plain(TcpStream)
       
    50 }
       
    51 
       
    52 #[cfg(feature = "tls-connections")]
       
    53 pub enum ClientSocket {
       
    54     SslHandshake(Option<MidHandshakeSslStream<TcpStream>>),
       
    55     SslStream(SslStream<TcpStream>)
       
    56 }
       
    57 
       
    58 impl ClientSocket {
       
    59     fn inner(&self) -> &TcpStream {
       
    60         #[cfg(not(feature = "tls-connections"))]
       
    61         match self {
       
    62             ClientSocket::Plain(stream) => stream,
       
    63         }
       
    64 
       
    65         #[cfg(feature = "tls-connections")]
       
    66         match self {
       
    67             ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(),
       
    68             ClientSocket::SslHandshake(None) => unreachable!(),
       
    69             ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref()
       
    70         }
       
    71     }
       
    72 }
       
    73 
    38 pub struct NetworkClient {
    74 pub struct NetworkClient {
    39     id: ClientId,
    75     id: ClientId,
    40     socket: TcpStream,
    76     socket: ClientSocket,
    41     peer_addr: SocketAddr,
    77     peer_addr: SocketAddr,
    42     decoder: ProtocolDecoder,
    78     decoder: ProtocolDecoder,
    43     buf_out: netbuf::Buf
    79     buf_out: netbuf::Buf
    44 }
    80 }
    45 
    81 
    46 impl NetworkClient {
    82 impl NetworkClient {
    47     pub fn new(id: ClientId, socket: TcpStream, peer_addr: SocketAddr) -> NetworkClient {
    83     pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient {
    48         NetworkClient {
    84         NetworkClient {
    49             id, socket, peer_addr,
    85             id, socket, peer_addr,
    50             decoder: ProtocolDecoder::new(),
    86             decoder: ProtocolDecoder::new(),
    51             buf_out: netbuf::Buf::new()
    87             buf_out: netbuf::Buf::new()
    52         }
    88         }
    53     }
    89     }
    54 
    90 
    55     pub fn read_messages(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> {
    91     fn read_impl<R: Read>(decoder: &mut ProtocolDecoder, source: &mut R,
       
    92                           id: ClientId, addr: &SocketAddr) -> NetworkResult<Vec<HWProtocolMessage>> {
    56         let mut bytes_read = 0;
    93         let mut bytes_read = 0;
    57         let result = loop {
    94         let result = loop {
    58             match self.decoder.read_from(&mut self.socket) {
    95             match decoder.read_from(source) {
    59                 Ok(bytes) => {
    96                 Ok(bytes) => {
    60                     debug!("Client {}: read {} bytes", self.id, bytes);
    97                     debug!("Client {}: read {} bytes", id, bytes);
    61                     bytes_read += bytes;
    98                     bytes_read += bytes;
    62                     if bytes == 0 {
    99                     if bytes == 0 {
    63                         let result = if bytes_read == 0 {
   100                         let result = if bytes_read == 0 {
    64                             info!("EOF for client {} ({})", self.id, self.peer_addr);
   101                             info!("EOF for client {} ({})", id, addr);
    65                             (Vec::new(), NetworkClientState::Closed)
   102                             (Vec::new(), NetworkClientState::Closed)
    66                         } else {
   103                         } else {
    67                             (self.decoder.extract_messages(), NetworkClientState::NeedsRead)
   104                             (decoder.extract_messages(), NetworkClientState::NeedsRead)
    68                         };
   105                         };
    69                         break Ok(result);
   106                         break Ok(result);
    70                     }
   107                     }
    71                     else if bytes_read >= MAX_BYTES_PER_READ {
   108                     else if bytes_read >= MAX_BYTES_PER_READ {
    72                         break Ok((self.decoder.extract_messages(), NetworkClientState::NeedsRead))
   109                         break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead))
    73                     }
   110                     }
    74                 }
   111                 }
    75                 Err(ref error) if error.kind() == ErrorKind::WouldBlock => {
   112                 Err(ref error) if error.kind() == ErrorKind::WouldBlock => {
    76                     let messages =  if bytes_read == 0 {
   113                     let messages =  if bytes_read == 0 {
    77                         Vec::new()
   114                         Vec::new()
    78                     } else {
   115                     } else {
    79                         self.decoder.extract_messages()
   116                         decoder.extract_messages()
    80                     };
   117                     };
    81                     break Ok((messages, NetworkClientState::Idle));
   118                     break Ok((messages, NetworkClientState::Idle));
    82                 }
   119                 }
    83                 Err(error) =>
   120                 Err(error) =>
    84                     break Err(error)
   121                     break Err(error)
    85             }
   122             }
    86         };
   123         };
    87         self.decoder.sweep();
   124         decoder.sweep();
    88         result
   125         result
    89     }
   126     }
    90 
   127 
    91     pub fn flush(&mut self) -> NetworkResult<()> {
   128     pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> {
       
   129         #[cfg(not(feature = "tls-connections"))]
       
   130         match self.socket {
       
   131             ClientSocket::Plain(ref mut stream) =>
       
   132                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr),
       
   133         }
       
   134 
       
   135         #[cfg(feature = "tls-connections")]
       
   136         match self.socket {
       
   137             ClientSocket::SslHandshake(ref mut handshake_opt) => {
       
   138                 let mut handshake = std::mem::replace(handshake_opt, None).unwrap();
       
   139 
       
   140                 match handshake.handshake() {
       
   141                     Ok(stream) => {
       
   142                         debug!("TLS handshake with {} ({}) completed", self.id, self.peer_addr);
       
   143                         self.socket = ClientSocket::SslStream(stream);
       
   144 
       
   145                         Ok((Vec::new(), NetworkClientState::Idle))
       
   146                     }
       
   147                     Err(HandshakeError::WouldBlock(new_handshake)) => {
       
   148                         *handshake_opt = Some(new_handshake);
       
   149                         Ok((Vec::new(), NetworkClientState::Idle))
       
   150                     }
       
   151                     Err(e) => {
       
   152                         debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr);
       
   153                         Err(Error::new(ErrorKind::Other, "Connection failure"))
       
   154                     }
       
   155                 }
       
   156             },
       
   157             ClientSocket::SslStream(ref mut stream) =>
       
   158                 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
       
   159         }
       
   160     }
       
   161 
       
   162     fn write_impl<W: Write>(buf_out: &mut netbuf::Buf, destination: &mut W) -> NetworkResult<()> {
    92         let result = loop {
   163         let result = loop {
    93             match self.buf_out.write_to(&mut self.socket) {
   164             match buf_out.write_to(destination) {
    94                 Ok(bytes) if self.buf_out.is_empty() || bytes == 0 =>
   165                 Ok(bytes) if buf_out.is_empty() || bytes == 0 =>
    95                     break Ok(((), NetworkClientState::Idle)),
   166                     break Ok(((), NetworkClientState::Idle)),
    96                 Ok(_) => (),
   167                 Ok(_) => (),
    97                 Err(ref error) if error.kind() == ErrorKind::Interrupted
   168                 Err(ref error) if error.kind() == ErrorKind::Interrupted
    98                     || error.kind() == ErrorKind::WouldBlock => {
   169                     || error.kind() == ErrorKind::WouldBlock => {
    99                     break Ok(((), NetworkClientState::NeedsWrite));
   170                     break Ok(((), NetworkClientState::NeedsWrite));
   100                 },
   171                 },
   101                 Err(error) =>
   172                 Err(error) =>
   102                     break Err(error)
   173                     break Err(error)
   103             }
   174             }
   104         };
   175         };
   105         self.socket.flush()?;
   176         result
       
   177     }
       
   178 
       
   179     pub fn write(&mut self) -> NetworkResult<()> {
       
   180         let result = {
       
   181             #[cfg(not(feature = "tls-connections"))]
       
   182             match self.socket {
       
   183                 ClientSocket::Plain(ref mut stream) =>
       
   184                     NetworkClient::write_impl(&mut self.buf_out, stream)
       
   185             }
       
   186 
       
   187             #[cfg(feature = "tls-connections")] {
       
   188                 match self.socket {
       
   189                     ClientSocket::SslHandshake(_) =>
       
   190                         Ok(((), NetworkClientState::Idle)),
       
   191                     ClientSocket::SslStream(ref mut stream) =>
       
   192                         NetworkClient::write_impl(&mut self.buf_out, stream)
       
   193                 }
       
   194             }
       
   195         };
       
   196 
       
   197         self.socket.inner().flush()?;
   106         result
   198         result
   107     }
   199     }
   108 
   200 
   109     pub fn send_raw_msg(&mut self, msg: &[u8]) {
   201     pub fn send_raw_msg(&mut self, msg: &[u8]) {
   110         self.buf_out.write_all(msg).unwrap();
   202         self.buf_out.write_all(msg).unwrap();
   115     }
   207     }
   116 
   208 
   117     pub fn send_msg(&mut self, msg: &HWServerMessage) {
   209     pub fn send_msg(&mut self, msg: &HWServerMessage) {
   118         self.send_string(&msg.to_raw_protocol());
   210         self.send_string(&msg.to_raw_protocol());
   119     }
   211     }
       
   212 }
       
   213 
       
   214 #[cfg(feature = "tls-connections")]
       
   215 struct ServerSsl {
       
   216     context: SslContext
   120 }
   217 }
   121 
   218 
   122 pub struct NetworkLayer {
   219 pub struct NetworkLayer {
   123     listener: TcpListener,
   220     listener: TcpListener,
   124     server: HWServer,
   221     server: HWServer,
   125     clients: Slab<NetworkClient>,
   222     clients: Slab<NetworkClient>,
   126     pending: HashSet<(ClientId, NetworkClientState)>,
   223     pending: HashSet<(ClientId, NetworkClientState)>,
   127     pending_cache: Vec<(ClientId, NetworkClientState)>
   224     pending_cache: Vec<(ClientId, NetworkClientState)>,
       
   225     #[cfg(feature = "tls-connections")]
       
   226     ssl: ServerSsl
   128 }
   227 }
   129 
   228 
   130 impl NetworkLayer {
   229 impl NetworkLayer {
   131     pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer {
   230     pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer {
   132         let server = HWServer::new(clients_limit, rooms_limit);
   231         let server = HWServer::new(clients_limit, rooms_limit);
   133         let clients = Slab::with_capacity(clients_limit);
   232         let clients = Slab::with_capacity(clients_limit);
   134         let pending = HashSet::with_capacity(2 * clients_limit);
   233         let pending = HashSet::with_capacity(2 * clients_limit);
   135         let pending_cache = Vec::with_capacity(2 * clients_limit);
   234         let pending_cache = Vec::with_capacity(2 * clients_limit);
   136         NetworkLayer {listener, server, clients, pending, pending_cache}
   235 
       
   236         NetworkLayer {
       
   237             listener, server, clients, pending, pending_cache,
       
   238             #[cfg(feature = "tls-connections")]
       
   239             ssl: NetworkLayer::create_ssl_context()
       
   240         }
       
   241     }
       
   242 
       
   243     #[cfg(feature = "tls-connections")]
       
   244     fn create_ssl_context() -> ServerSsl {
       
   245         let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap();
       
   246         builder.set_verify(SslVerifyMode::NONE);
       
   247         builder.set_read_ahead(true);
       
   248         builder.set_certificate_file("ssl/cert.pem", SslFiletype::PEM).unwrap();
       
   249         builder.set_private_key_file("ssl/key.pem", SslFiletype::PEM).unwrap();
       
   250         builder.set_options(SslOptions::NO_COMPRESSION);
       
   251         builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap();
       
   252         ServerSsl { context: builder.build() }
   137     }
   253     }
   138 
   254 
   139     pub fn register_server(&self, poll: &Poll) -> io::Result<()> {
   255     pub fn register_server(&self, poll: &Poll) -> io::Result<()> {
   140         poll.register(&self.listener, utils::SERVER, Ready::readable(),
   256         poll.register(&self.listener, utils::SERVER, Ready::readable(),
   141                       PollOpt::edge())
   257                       PollOpt::edge())
   142     }
   258     }
   143 
   259 
   144     fn deregister_client(&mut self, poll: &Poll, id: ClientId) {
   260     fn deregister_client(&mut self, poll: &Poll, id: ClientId) {
   145         let mut client_exists = false;
   261         let mut client_exists = false;
   146         if let Some(ref client) = self.clients.get(id) {
   262         if let Some(ref client) = self.clients.get(id) {
   147             poll.deregister(&client.socket)
   263             poll.deregister(client.socket.inner())
   148                 .expect("could not deregister socket");
   264                 .expect("could not deregister socket");
   149             info!("client {} ({}) removed", client.id, client.peer_addr);
   265             info!("client {} ({}) removed", client.id, client.peer_addr);
   150             client_exists = true;
   266             client_exists = true;
   151         }
   267         }
   152         if client_exists {
   268         if client_exists {
   153             self.clients.remove(id);
   269             self.clients.remove(id);
   154         }
   270         }
   155     }
   271     }
   156 
   272 
   157     fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: TcpStream, addr: SocketAddr) {
   273     fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: ClientSocket, addr: SocketAddr) {
   158         poll.register(&client_socket, Token(id),
   274         poll.register(client_socket.inner(), Token(id),
   159                       Ready::readable() | Ready::writable(),
   275                       Ready::readable() | Ready::writable(),
   160                       PollOpt::edge())
   276                       PollOpt::edge())
   161             .expect("could not register socket with event loop");
   277             .expect("could not register socket with event loop");
   162 
   278 
   163         let entry = self.clients.vacant_entry();
   279         let entry = self.clients.vacant_entry();
   178                 }
   294                 }
   179             }
   295             }
   180         }
   296         }
   181     }
   297     }
   182 
   298 
       
   299     fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
       
   300         #[cfg(not(feature = "tls-connections"))] {
       
   301             Ok(ClientSocket::Plain(socket))
       
   302         }
       
   303 
       
   304         #[cfg(feature = "tls-connections")] {
       
   305             let ssl = Ssl::new(&self.ssl.context).unwrap();
       
   306             let mut builder = SslStreamBuilder::new(ssl, socket);
       
   307             builder.set_accept_state();
       
   308             match builder.handshake() {
       
   309                 Ok(stream) =>
       
   310                     Ok(ClientSocket::SslStream(stream)),
       
   311                 Err(HandshakeError::WouldBlock(stream)) =>
       
   312                     Ok(ClientSocket::SslHandshake(Some(stream))),
       
   313                 Err(e) => {
       
   314                     debug!("OpenSSL handshake failed: {}", e);
       
   315                     Err(Error::new(ErrorKind::Other, "Connection failure"))
       
   316                 }
       
   317             }
       
   318         }
       
   319     }
       
   320 
   183     pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> {
   321     pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> {
   184         let (client_socket, addr) = self.listener.accept()?;
   322         let (client_socket, addr) = self.listener.accept()?;
   185         info!("Connected: {}", addr);
   323         info!("Connected: {}", addr);
   186 
   324 
   187         let client_id = self.server.add_client();
   325         let client_id = self.server.add_client();
   188         self.register_client(poll, client_id, client_socket, addr);
   326         self.register_client(poll, client_id, self.create_client_socket(client_socket)?, addr);
   189         self.flush_server_messages();
   327         self.flush_server_messages();
   190 
   328 
   191         Ok(())
   329         Ok(())
   192     }
   330     }
   193 
   331 
   203 
   341 
   204     pub fn client_readable(&mut self, poll: &Poll,
   342     pub fn client_readable(&mut self, poll: &Poll,
   205                            client_id: ClientId) -> io::Result<()> {
   343                            client_id: ClientId) -> io::Result<()> {
   206         let messages =
   344         let messages =
   207             if let Some(ref mut client) = self.clients.get_mut(client_id) {
   345             if let Some(ref mut client) = self.clients.get_mut(client_id) {
   208                 client.read_messages()
   346                 client.read()
   209             } else {
   347             } else {
   210                 warn!("invalid readable client: {}", client_id);
   348                 warn!("invalid readable client: {}", client_id);
   211                 Ok((Vec::new(), NetworkClientState::Idle))
   349                 Ok((Vec::new(), NetworkClientState::Idle))
   212             };
   350             };
   213 
   351 
   244 
   382 
   245     pub fn client_writable(&mut self, poll: &Poll,
   383     pub fn client_writable(&mut self, poll: &Poll,
   246                            client_id: ClientId) -> io::Result<()> {
   384                            client_id: ClientId) -> io::Result<()> {
   247         let result =
   385         let result =
   248             if let Some(ref mut client) = self.clients.get_mut(client_id) {
   386             if let Some(ref mut client) = self.clients.get_mut(client_id) {
   249                 client.flush()
   387                 client.write()
   250             } else {
   388             } else {
   251                 warn!("invalid writable client: {}", client_id);
   389                 warn!("invalid writable client: {}", client_id);
   252                 Ok(((), NetworkClientState::Idle))
   390                 Ok(((), NetworkClientState::Idle))
   253             };
   391             };
   254 
   392