fix connection errors carrying over to new clients
authoralfadur
Wed, 14 Feb 2024 02:07:35 +0300
changeset 16019 c40f5e27aaf0
parent 16018 fb389df02e3e
child 16020 00bf5adba849
fix connection errors carrying over to new clients
rust/hedgewars-server/src/handlers/inanteroom.rs
rust/hedgewars-server/src/handlers/inlobby.rs
rust/hedgewars-server/src/server/network.rs
--- a/rust/hedgewars-server/src/handlers/inanteroom.rs	Tue Feb 13 00:58:17 2024 +0300
+++ b/rust/hedgewars-server/src/handlers/inanteroom.rs	Wed Feb 14 02:07:35 2024 +0300
@@ -60,6 +60,7 @@
     response: &mut super::Response,
     message: HwProtocolMessage,
 ) -> LoginResult {
+    //todo!("Handle parsing of empty nicks")
     match message {
         HwProtocolMessage::Quit(_) => {
             response.add(Bye("User quit".to_string()).send_self());
--- a/rust/hedgewars-server/src/handlers/inlobby.rs	Tue Feb 13 00:58:17 2024 +0300
+++ b/rust/hedgewars-server/src/handlers/inlobby.rs	Wed Feb 14 02:07:35 2024 +0300
@@ -27,7 +27,7 @@
     use hedgewars_network_protocol::messages::HwProtocolMessage::*;
 
     //todo!("add kick/ban handlers");
-    //todo!("add kick/ban handling");
+    //todo!("add command for forwarding lobby chat into rooms
 
     match message {
         CreateRoom(name, password) => match server.create_room(client_id, name, password) {
--- a/rust/hedgewars-server/src/server/network.rs	Tue Feb 13 00:58:17 2024 +0300
+++ b/rust/hedgewars-server/src/server/network.rs	Wed Feb 14 02:07:35 2024 +0300
@@ -30,11 +30,13 @@
 
 const PING_TIMEOUT: Duration = Duration::from_secs(15);
 
+#[derive(Debug)]
 enum ClientUpdateData {
     Message(HwProtocolMessage),
     Error(String),
 }
 
+#[derive(Debug)]
 struct ClientUpdate {
     client_id: ClientId,
     data: ClientUpdateData,
@@ -187,6 +189,7 @@
                         }
                         Err(e) => {
                             //todo!("send cmdline errors");
+                            //todo!("more graceful shutdown to prevent errors from explicitly closed clients")
                             sender.send(Error(format!("{}", e))).await;
                             if matches!(e, ProtocolError::Timeout) {
                                 Self::write(&mut self.stream, Bytes::from(HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol())).await;
@@ -212,12 +215,12 @@
     tls: TlsListener,
     server_state: ServerState,
     clients: Slab<Sender<Bytes>>,
+    update_tx: Sender<ClientUpdate>,
+    update_rx: Receiver<ClientUpdate>
 }
 
 impl NetworkLayer {
     pub async fn run(&mut self) {
-        let (update_tx, mut update_rx) = channel(128);
-
         async fn accept_plain_branch(
             layer: &mut NetworkLayer,
             value: (TcpStream, SocketAddr),
@@ -280,15 +283,15 @@
         loop {
             #[cfg(not(feature = "tls-connections"))]
             tokio::select! {
-                Ok(value) = self.listener.accept() => accept_plain_branch(self, value, update_tx.clone()).await,
-                client_message = update_rx.recv(), if !self.clients.is_empty() => client_message_branch(self, client_message).await
+                Ok(value) = self.listener.accept() => accept_plain_branch(self, value, self.update_tx.clone()).await,
+                client_message = self.update_rx.recv(), if !self.clients.is_empty() => client_message_branch(self, client_message).await
             }
 
             #[cfg(feature = "tls-connections")]
             tokio::select! {
-                Ok(value) = self.listener.accept() => accept_plain_branch(self, value, update_tx.clone()).await,
-                Ok(value) = self.tls.listener.accept() => accept_tls_branch(self, value, update_tx.clone()).await,
-                client_message = update_rx.recv(), if !self.clients.is_empty() => client_message_branch(self, client_message).await
+                Ok(value) = self.listener.accept() => accept_plain_branch(self, value, self.update_tx.clone()).await,
+                Ok(value) = self.tls.listener.accept() => accept_tls_branch(self, value, self.update_tx.clone()).await,
+                client_message = self.update_rx.recv(), if !self.clients.is_empty() => client_message_branch(self, client_message).await
             }
         }
     }
@@ -342,19 +345,24 @@
             return;
         }
 
+        for client_id in response.extract_removed_clients() {
+            if self.clients.contains(client_id) {
+                self.clients.remove(client_id);
+                if self.clients.is_empty() {
+                    let (update_tx, update_rx) = channel(128);
+                    self.update_rx = update_rx;
+                    self.update_tx = update_tx;
+                }
+            }
+            info!("Client {} removed", client_id);
+        }
+
         debug!("{} pending server messages", response.len());
         let output = response.extract_messages(&mut self.server_state.server);
         for (clients, message) in output {
             debug!("Message {:?} to {:?}", message, clients);
             Self::send_message(&mut self.clients, message, clients.iter().cloned()).await;
         }
-
-        for client_id in response.extract_removed_clients() {
-            if self.clients.contains(client_id) {
-                self.clients.remove(client_id);
-            }
-            info!("Client {} removed", client_id);
-        }
     }
 
     async fn send_message<I>(
@@ -428,6 +436,7 @@
         let server_state = ServerState::new(self.clients_capacity, self.rooms_capacity);
 
         let clients = Slab::with_capacity(self.clients_capacity);
+        let (update_tx, update_rx) = channel(128);
 
         NetworkLayer {
             listener: self.listener.expect("No listener provided"),
@@ -438,6 +447,8 @@
             },
             server_state,
             clients,
+            update_tx,
+            update_rx
         }
     }
 }