rust/hedgewars-server/src/server/network.rs
changeset 15936 c5c53ebb2d91
parent 15832 a4d505a32879
child 15937 e514ceb5e7d6
equal deleted inserted replaced
15935:cd3d16905e0e 15936:c5c53ebb2d91
     1 use bytes::{Buf, Bytes};
     1 use bytes::{Buf, Bytes};
     2 use log::*;
     2 use log::*;
     3 use slab::Slab;
     3 use slab::Slab;
       
     4 use std::io::Error;
       
     5 use std::pin::Pin;
       
     6 use std::task::{Context, Poll};
     4 use std::{
     7 use std::{
     5     iter::Iterator,
     8     iter::Iterator,
     6     net::{IpAddr, SocketAddr},
     9     net::{IpAddr, SocketAddr},
     7     time::Duration,
    10     time::Duration,
     8 };
    11 };
     9 use tokio::{
    12 use tokio::{
    10     io::AsyncReadExt,
    13     io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
    11     net::{TcpListener, TcpStream},
    14     net::{TcpListener, TcpStream},
    12     sync::mpsc::{channel, Receiver, Sender},
    15     sync::mpsc::{channel, Receiver, Sender},
    13 };
    16 };
       
    17 #[cfg(feature = "tls-connections")]
       
    18 use tokio_native_tls::{TlsAcceptor, TlsStream};
    14 
    19 
    15 use crate::{
    20 use crate::{
    16     core::{
    21     core::{
    17         events::{TimedEvents, Timeout},
    22         events::{TimedEvents, Timeout},
    18         types::ClientId,
    23         types::ClientId,
    23     utils,
    28     utils,
    24 };
    29 };
    25 use hedgewars_network_protocol::{
    30 use hedgewars_network_protocol::{
    26     messages::HwServerMessage::Redirect, messages::*, parser::server_message,
    31     messages::HwServerMessage::Redirect, messages::*, parser::server_message,
    27 };
    32 };
    28 use tokio::io::AsyncWriteExt;
       
    29 
    33 
    30 const PING_TIMEOUT: Duration = Duration::from_secs(15);
    34 const PING_TIMEOUT: Duration = Duration::from_secs(15);
    31 
    35 
    32 enum ClientUpdateData {
    36 enum ClientUpdateData {
    33     Message(HwProtocolMessage),
    37     Message(HwProtocolMessage),
    54             .await
    58             .await
    55             .is_ok()
    59             .is_ok()
    56     }
    60     }
    57 }
    61 }
    58 
    62 
       
    63 enum ClientStream {
       
    64     Tcp(TcpStream),
       
    65     #[cfg(feature = "tls-connections")]
       
    66     Tls(TlsStream<TcpStream>),
       
    67 }
       
    68 
       
    69 impl Unpin for ClientStream {}
       
    70 
       
    71 impl AsyncRead for ClientStream {
       
    72     fn poll_read(
       
    73         self: Pin<&mut Self>,
       
    74         cx: &mut Context<'_>,
       
    75         buf: &mut ReadBuf<'_>,
       
    76     ) -> Poll<std::io::Result<()>> {
       
    77         use ClientStream::*;
       
    78         match Pin::into_inner(self) {
       
    79             Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
       
    80             #[cfg(feature = "tls-connections")]
       
    81             Tls(stream) => Pin::new(stream).poll_read(cx, buf),
       
    82         }
       
    83     }
       
    84 }
       
    85 
       
    86 impl AsyncWrite for ClientStream {
       
    87     fn poll_write(
       
    88         self: Pin<&mut Self>,
       
    89         cx: &mut Context<'_>,
       
    90         buf: &[u8],
       
    91     ) -> Poll<Result<usize, Error>> {
       
    92         use ClientStream::*;
       
    93         match Pin::into_inner(self) {
       
    94             Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
       
    95             #[cfg(feature = "tls-connections")]
       
    96             Tls(stream) => Pin::new(stream).poll_write(cx, buf),
       
    97         }
       
    98     }
       
    99 
       
   100     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
       
   101         use ClientStream::*;
       
   102         match Pin::into_inner(self) {
       
   103             Tcp(stream) => Pin::new(stream).poll_flush(cx),
       
   104             #[cfg(feature = "tls-connections")]
       
   105             Tls(stream) => Pin::new(stream).poll_flush(cx),
       
   106         }
       
   107     }
       
   108 
       
   109     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
       
   110         use ClientStream::*;
       
   111         match Pin::into_inner(self) {
       
   112             Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
       
   113             #[cfg(feature = "tls-connections")]
       
   114             Tls(stream) => Pin::new(stream).poll_shutdown(cx),
       
   115         }
       
   116     }
       
   117 }
       
   118 
    59 struct NetworkClient {
   119 struct NetworkClient {
    60     id: ClientId,
   120     id: ClientId,
    61     socket: TcpStream,
   121     stream: ClientStream,
    62     receiver: Receiver<Bytes>,
   122     receiver: Receiver<Bytes>,
    63     peer_addr: SocketAddr,
   123     peer_addr: SocketAddr,
    64     decoder: ProtocolDecoder,
   124     decoder: ProtocolDecoder,
    65 }
   125 }
    66 
   126 
    67 impl NetworkClient {
   127 impl NetworkClient {
    68     fn new(
   128     fn new(
    69         id: ClientId,
   129         id: ClientId,
    70         socket: TcpStream,
   130         stream: ClientStream,
    71         peer_addr: SocketAddr,
   131         peer_addr: SocketAddr,
    72         receiver: Receiver<Bytes>,
   132         receiver: Receiver<Bytes>,
    73     ) -> Self {
   133     ) -> Self {
    74         Self {
   134         Self {
    75             id,
   135             id,
    76             socket,
   136             stream,
    77             peer_addr,
   137             peer_addr,
    78             receiver,
   138             receiver,
    79             decoder: ProtocolDecoder::new(PING_TIMEOUT),
   139             decoder: ProtocolDecoder::new(PING_TIMEOUT),
    80         }
   140         }
    81     }
   141     }
    82 
   142 
    83     async fn read(
   143     async fn read<T: AsyncRead + AsyncWrite + Unpin>(
    84         socket: &mut TcpStream,
   144         stream: &mut T,
    85         decoder: &mut ProtocolDecoder,
   145         decoder: &mut ProtocolDecoder,
    86     ) -> protocol::Result<HwProtocolMessage> {
   146     ) -> protocol::Result<HwProtocolMessage> {
    87         let result = decoder.read_from(socket).await;
   147         let result = decoder.read_from(stream).await;
    88         if matches!(result, Err(ProtocolError::Timeout)) {
   148         if matches!(result, Err(ProtocolError::Timeout)) {
    89             if Self::write(socket, Bytes::from(HwServerMessage::Ping.to_raw_protocol())).await {
   149             if Self::write(stream, Bytes::from(HwServerMessage::Ping.to_raw_protocol())).await {
    90                 decoder.read_from(socket).await
   150                 decoder.read_from(stream).await
    91             } else {
   151             } else {
    92                 Err(ProtocolError::Eof)
   152                 Err(ProtocolError::Eof)
    93             }
   153             }
    94         } else {
   154         } else {
    95             result
   155             result
    96         }
   156         }
    97     }
   157     }
    98 
   158 
    99     async fn write(socket: &mut TcpStream, mut data: Bytes) -> bool {
   159     async fn write<T: AsyncWrite + Unpin>(stream: &mut T, mut data: Bytes) -> bool {
   100         !data.has_remaining() || matches!(socket.write_buf(&mut data).await, Ok(n) if n > 0)
   160         !data.has_remaining() || matches!(stream.write_buf(&mut data).await, Ok(n) if n > 0)
   101     }
   161     }
   102 
   162 
   103     async fn run(mut self, sender: Sender<ClientUpdate>) {
   163     async fn run(mut self, sender: Sender<ClientUpdate>) {
   104         use ClientUpdateData::*;
   164         use ClientUpdateData::*;
   105         let mut sender = ClientUpdateSender {
   165         let mut sender = ClientUpdateSender {
   109 
   169 
   110         loop {
   170         loop {
   111             tokio::select! {
   171             tokio::select! {
   112                 server_message = self.receiver.recv() => {
   172                 server_message = self.receiver.recv() => {
   113                     match server_message {
   173                     match server_message {
   114                         Some(message) => if !Self::write(&mut self.socket, message).await {
   174                         Some(message) => if !Self::write(&mut self.stream, message).await {
   115                             sender.send(Error("Connection reset by peer".to_string())).await;
   175                             sender.send(Error("Connection reset by peer".to_string())).await;
   116                             break;
   176                             break;
   117                         }
   177                         }
   118                         None => {
   178                         None => {
   119                             break;
   179                             break;
   120                         }
   180                         }
   121                     }
   181                     }
   122                 }
   182                 }
   123                 client_message = Self::read(&mut self.socket, &mut self.decoder) => {
   183                 client_message = Self::read(&mut self.stream, &mut self.decoder) => {
   124                      match client_message {
   184                      match client_message {
   125                         Ok(message) => {
   185                         Ok(message) => {
   126                             if !sender.send(Message(message)).await {
   186                             if !sender.send(Message(message)).await {
   127                                 break;
   187                                 break;
   128                             }
   188                             }
   129                         }
   189                         }
   130                         Err(e) => {
   190                         Err(e) => {
   131                             sender.send(Error(format!("{}", e))).await;
   191                             sender.send(Error(format!("{}", e))).await;
   132                             if matches!(e, ProtocolError::Timeout) {
   192                             if matches!(e, ProtocolError::Timeout) {
   133                                 Self::write(&mut self.socket, Bytes::from(HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol())).await;
   193                                 Self::write(&mut self.stream, Bytes::from(HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol())).await;
   134                             }
   194                             }
   135                             break;
   195                             break;
   136                         }
   196                         }
   137                     }
   197                     }
   138                 }
   198                 }
   139             }
   199             }
   140         }
   200         }
   141     }
   201     }
   142 }
   202 }
   143 
   203 
       
   204 #[cfg(feature = "tls-connections")]
       
   205 struct TlsListener {
       
   206     listener: TcpListener,
       
   207     acceptor: TlsAcceptor,
       
   208 }
       
   209 
   144 pub struct NetworkLayer {
   210 pub struct NetworkLayer {
   145     listener: TcpListener,
   211     listener: TcpListener,
       
   212     #[cfg(feature = "tls-connections")]
       
   213     tls: TlsListener,
   146     server_state: ServerState,
   214     server_state: ServerState,
   147     clients: Slab<Sender<Bytes>>,
   215     clients: Slab<Sender<Bytes>>,
   148 }
   216 }
   149 
   217 
   150 impl NetworkLayer {
   218 impl NetworkLayer {
   151     pub async fn run(&mut self) {
   219     pub async fn run(&mut self) {
   152         let (update_tx, mut update_rx) = channel(128);
   220         let (update_tx, mut update_rx) = channel(128);
   153 
   221 
   154         loop {
   222         async fn accept_plain_branch(
   155             tokio::select! {
   223             layer: &mut NetworkLayer,
   156                 Ok((stream, addr)) = self.listener.accept() => {
   224             value: (TcpStream, SocketAddr),
   157                     if let Some(client) = self.create_client(stream, addr).await {
   225             update_tx: Sender<ClientUpdate>,
       
   226         ) {
       
   227             let (stream, addr) = value;
       
   228             if let Some(client) = layer.create_client(ClientStream::Tcp(stream), addr).await {
       
   229                 tokio::spawn(client.run(update_tx));
       
   230             }
       
   231         }
       
   232 
       
   233         #[cfg(feature = "tls-connections")]
       
   234         async fn accept_tls_branch(
       
   235             layer: &mut NetworkLayer,
       
   236             value: (TcpStream, SocketAddr),
       
   237             update_tx: Sender<ClientUpdate>,
       
   238         ) {
       
   239             let (stream, addr) = value;
       
   240             match layer.tls.acceptor.accept(stream).await {
       
   241                 Ok(stream) => {
       
   242                     if let Some(client) = layer.create_client(ClientStream::Tls(stream), addr).await
       
   243                     {
   158                         tokio::spawn(client.run(update_tx.clone()));
   244                         tokio::spawn(client.run(update_tx.clone()));
   159                     }
   245                     }
   160                 }
   246                 }
   161                 client_message = update_rx.recv(), if !self.clients.is_empty() => {
   247                 Err(e) => {
   162                     use ClientUpdateData::*;
   248                     warn!("Unable to establish TLS connection: {}", e);
   163                     match client_message {
   249                 }
   164                         Some(ClientUpdate{ client_id, data: Message(message) } ) => {
   250             }
   165                             self.handle_message(client_id, message).await;
   251         }
   166                         }
   252 
   167                         Some(ClientUpdate{ client_id, data: Error(e) } ) => {
   253         async fn client_message_branch(
   168                             let mut response = handlers::Response::new(client_id);
   254             layer: &mut NetworkLayer,
   169                             info!("Client {} error: {:?}", client_id, e);
   255             client_message: Option<ClientUpdate>,
   170                             response.remove_client(client_id);
   256         ) {
   171                             handlers::handle_client_loss(&mut self.server_state, client_id, &mut response);
   257             use ClientUpdateData::*;
   172                             self.handle_response(response).await;
   258             match client_message {
   173                         }
   259                 Some(ClientUpdate {
   174                         None => unreachable!()
   260                     client_id,
   175                     }
   261                     data: Message(message),
   176                 }
   262                 }) => {
       
   263                     layer.handle_message(client_id, message).await;
       
   264                 }
       
   265                 Some(ClientUpdate {
       
   266                     client_id,
       
   267                     data: Error(e),
       
   268                 }) => {
       
   269                     let mut response = handlers::Response::new(client_id);
       
   270                     info!("Client {} error: {:?}", client_id, e);
       
   271                     response.remove_client(client_id);
       
   272                     handlers::handle_client_loss(&mut layer.server_state, client_id, &mut response);
       
   273                     layer.handle_response(response).await;
       
   274                 }
       
   275                 None => unreachable!(),
       
   276             }
       
   277         }
       
   278 
       
   279         loop {
       
   280             #[cfg(not(feature = "tls-connections"))]
       
   281             tokio::select! {
       
   282                 Ok(value) = self.listener.accept() => accept_plain_branch(self, value, update_tx.clone()).await,
       
   283                 client_message = update_rx.recv(), if !self.clients.is_empty() => client_message_branch(self, client_message).await
       
   284             }
       
   285 
       
   286             #[cfg(feature = "tls-connections")]
       
   287             tokio::select! {
       
   288                 Ok(value) = self.listener.accept() => accept_plain_branch(self, value, update_tx.clone()).await,
       
   289                 Ok(value) = self.tls.listener.accept() => accept_tls_branch(self, value, update_tx.clone()).await,
       
   290                 client_message = update_rx.recv(), if !self.clients.is_empty() => client_message_branch(self, client_message).await
   177             }
   291             }
   178         }
   292         }
   179     }
   293     }
   180 
   294 
   181     async fn create_client(
   295     async fn create_client(
   182         &mut self,
   296         &mut self,
   183         stream: TcpStream,
   297         stream: ClientStream,
   184         addr: SocketAddr,
   298         addr: SocketAddr,
   185     ) -> Option<NetworkClient> {
   299     ) -> Option<NetworkClient> {
   186         let entry = self.clients.vacant_entry();
   300         let entry = self.clients.vacant_entry();
   187         let client_id = entry.key();
   301         let client_id = entry.key();
   188         let (tx, rx) = channel(16);
   302         let (tx, rx) = channel(16);
   261     }
   375     }
   262 }
   376 }
   263 
   377 
   264 pub struct NetworkLayerBuilder {
   378 pub struct NetworkLayerBuilder {
   265     listener: Option<TcpListener>,
   379     listener: Option<TcpListener>,
       
   380     #[cfg(feature = "tls-connections")]
       
   381     tls_listener: Option<TcpListener>,
       
   382     #[cfg(feature = "tls-connections")]
       
   383     tls_acceptor: Option<TlsAcceptor>,
   266     clients_capacity: usize,
   384     clients_capacity: usize,
   267     rooms_capacity: usize,
   385     rooms_capacity: usize,
   268 }
   386 }
   269 
   387 
   270 impl Default for NetworkLayerBuilder {
   388 impl Default for NetworkLayerBuilder {
   271     fn default() -> Self {
   389     fn default() -> Self {
   272         Self {
   390         Self {
   273             clients_capacity: 1024,
   391             clients_capacity: 1024,
   274             rooms_capacity: 512,
   392             rooms_capacity: 512,
   275             listener: None,
   393             listener: None,
       
   394             #[cfg(feature = "tls-connections")]
       
   395             tls_listener: None,
       
   396             #[cfg(feature = "tls-connections")]
       
   397             tls_acceptor: None,
   276         }
   398         }
   277     }
   399     }
   278 }
   400 }
   279 
   401 
   280 impl NetworkLayerBuilder {
   402 impl NetworkLayerBuilder {
   283             listener: Some(listener),
   405             listener: Some(listener),
   284             ..self
   406             ..self
   285         }
   407         }
   286     }
   408     }
   287 
   409 
       
   410     #[cfg(feature = "tls-connections")]
       
   411     pub fn with_tls_acceptor(self, listener: TlsAcceptor) -> Self {
       
   412         Self {
       
   413             tls_acceptor: Option::from(listener),
       
   414             ..self
       
   415         }
       
   416     }
       
   417 
       
   418     #[cfg(feature = "tls-connections")]
       
   419     pub fn with_tls_listener(self, listener: TlsAcceptor) -> Self {
       
   420         Self {
       
   421             tls_acceptor: Option::from(listener),
       
   422             ..self
       
   423         }
       
   424     }
       
   425 
   288     pub fn build(self) -> NetworkLayer {
   426     pub fn build(self) -> NetworkLayer {
   289         let server_state = ServerState::new(self.clients_capacity, self.rooms_capacity);
   427         let server_state = ServerState::new(self.clients_capacity, self.rooms_capacity);
   290 
   428 
   291         let clients = Slab::with_capacity(self.clients_capacity);
   429         let clients = Slab::with_capacity(self.clients_capacity);
   292 
   430 
   293         NetworkLayer {
   431         NetworkLayer {
   294             listener: self.listener.expect("No listener provided"),
   432             listener: self.listener.expect("No listener provided"),
       
   433             #[cfg(feature = "tls-connections")]
       
   434             tls: TlsListener {
       
   435                 listener: self.tls_listener.expect("No TLS listener provided"),
       
   436                 acceptor: self.tls_acceptor.expect("No TLS acceptor provided"),
       
   437             },
   295             server_state,
   438             server_state,
   296             clients,
   439             clients,
   297         }
   440         }
   298     }
   441     }
   299 }
   442 }