handle edge polling properly
authoralfadur
Thu, 14 Jun 2018 12:31:15 -0400
changeset 13419 28b314ad566d
parent 13417 236cc4cf2448
child 13420 0eedc17055a0
handle edge polling properly
gameServer2/src/main.rs
gameServer2/src/server/network.rs
--- a/gameServer2/src/main.rs	Sun Jun 10 19:01:50 2018 +0200
+++ b/gameServer2/src/main.rs	Thu Jun 14 12:31:15 2018 -0400
@@ -22,6 +22,7 @@
 mod protocol;
 
 use server::network::NetworkLayer;
+use std::time::Duration;
 
 fn main() {
     env_logger::init().unwrap();
@@ -38,7 +39,12 @@
     let mut events = Events::with_capacity(1024);
 
     loop {
-        poll.poll(&mut events, None).unwrap();
+        let timeout = if hw_network.has_pending_operations() {
+            Some(Duration::from_millis(1))
+        } else {
+            None
+        };
+        poll.poll(&mut events, timeout).unwrap();
 
         for event in events.iter() {
             if event.readiness() & Ready::readable() == Ready::readable() {
@@ -60,5 +66,6 @@
 //                }
 //            }
         }
+        hw_network.on_idle(&poll).unwrap();
     }
 }
--- a/gameServer2/src/server/network.rs	Sun Jun 10 19:01:50 2018 +0200
+++ b/gameServer2/src/server/network.rs	Thu Jun 14 12:31:15 2018 -0400
@@ -1,29 +1,43 @@
 extern crate slab;
 
-use std::io::ErrorKind;
-use mio::net::*;
-use super::server::{HWServer, PendingMessage, Destination};
-use super::client::ClientId;
+use std::{
+    io, io::{Error, ErrorKind, Write},
+    net::{SocketAddr, IpAddr, Ipv4Addr},
+    collections::VecDeque
+};
+
+use mio::{
+    net::{TcpStream, TcpListener},
+    Poll, PollOpt, Ready, Token
+};
+use netbuf;
 use slab::Slab;
 
-use mio::net::TcpStream;
-use mio::*;
-use std::io::Write;
-use std::io;
-use netbuf;
+use utils;
+use protocol::{ProtocolDecoder, messages::*};
+use super::{
+    server::{HWServer, PendingMessage, Destination},
+    client::ClientId
+};
+
+const MAX_BYTES_PER_READ: usize = 2048;
 
-use utils;
-use protocol::ProtocolDecoder;
-use protocol::messages::*;
-use std::net::SocketAddr;
+#[derive(PartialEq, Copy, Clone)]
+pub enum NetworkClientState {
+    Idle,
+    NeedsWrite,
+    NeedsRead,
+    Closed,
+}
+
+type NetworkResult<T> = io::Result<(T, NetworkClientState)>;
 
 pub struct NetworkClient {
     id: ClientId,
     socket: TcpStream,
     peer_addr: SocketAddr,
     decoder: ProtocolDecoder,
-    buf_out: netbuf::Buf,
-    closed: bool
+    buf_out: netbuf::Buf
 }
 
 impl NetworkClient {
@@ -31,14 +45,67 @@
         NetworkClient {
             id, socket, peer_addr,
             decoder: ProtocolDecoder::new(),
-            buf_out: netbuf::Buf::new(),
-            closed: false
+            buf_out: netbuf::Buf::new()
         }
     }
 
+    pub fn read_messages(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> {
+        let mut bytes_read = 0;
+        let result = loop {
+            match self.decoder.read_from(&mut self.socket) {
+                Ok(bytes) => {
+                    debug!("Read {} bytes", bytes);
+                    bytes_read += bytes;
+                    if bytes == 0 {
+                        let result = if bytes_read == 0 {
+                            info!("EOF for client {} ({})", self.id, self.peer_addr);
+                            (Vec::new(), NetworkClientState::Closed)
+                        } else {
+                            (self.decoder.extract_messages(), NetworkClientState::NeedsRead)
+                        };
+                        break Ok(result);
+                    }
+                    else if bytes_read >= MAX_BYTES_PER_READ {
+                        break Ok((self.decoder.extract_messages(), NetworkClientState::NeedsRead))
+                    }
+                }
+                Err(ref error) if error.kind() == ErrorKind::WouldBlock => {
+                    let messages =  if bytes_read == 0 {
+                        Vec::new()
+                    } else {
+                        self.decoder.extract_messages()
+                    };
+                    break Ok((messages, NetworkClientState::Idle));
+                }
+                Err(error) =>
+                    break Err(error)
+            }
+        };
+        self.decoder.sweep();
+        result
+    }
+
+    pub fn flush(&mut self) -> NetworkResult<()> {
+        let result = loop {
+            match self.buf_out.write_to(&mut self.socket) {
+                Ok(bytes) if self.buf_out.is_empty() || bytes == 0 =>
+                    break Ok(((), NetworkClientState::Idle)),
+                Ok(bytes) =>
+                    (),
+                Err(ref error) if error.kind() == ErrorKind::Interrupted
+                    || error.kind() == ErrorKind::WouldBlock => {
+                    break Ok(((), NetworkClientState::NeedsWrite));
+                },
+                Err(error) =>
+                    break Err(error)
+            }
+        };
+        self.socket.flush()?;
+        result
+    }
+
     pub fn send_raw_msg(&mut self, msg: &[u8]) {
         self.buf_out.write(msg).unwrap();
-        self.flush();
     }
 
     pub fn send_string(&mut self, msg: &String) {
@@ -48,42 +115,22 @@
     pub fn send_msg(&mut self, msg: HWServerMessage) {
         self.send_string(&msg.to_raw_protocol());
     }
-
-    fn flush(&mut self) {
-        self.buf_out.write_to(&mut self.socket).unwrap();
-        self.socket.flush().unwrap();
-    }
-
-    pub fn read_messages(&mut self) -> io::Result<Vec<HWProtocolMessage>> {
-        let bytes_read = self.decoder.read_from(&mut self.socket)?;
-        debug!("Read {} bytes", bytes_read);
-
-        if bytes_read == 0 {
-            self.closed = true;
-            info!("EOF for client {} ({})", self.id, self.peer_addr);
-        }
-
-        Ok(self.decoder.extract_messages())
-    }
-
-    pub fn write_messages(&mut self) -> io::Result<()> {
-        self.buf_out.write_to(&mut self.socket)?;
-        Ok(())
-    }
 }
 
 pub struct NetworkLayer {
     listener: TcpListener,
     server: HWServer,
 
-    clients: Slab<NetworkClient>
+    clients: Slab<NetworkClient>,
+    pending: VecDeque<(ClientId, NetworkClientState)>
 }
 
 impl NetworkLayer {
     pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer {
         let server = HWServer::new(clients_limit, rooms_limit);
         let clients = Slab::with_capacity(clients_limit);
-        NetworkLayer {listener, server, clients}
+        let pending = VecDeque::with_capacity(clients_limit);
+        NetworkLayer {listener, server, clients, pending}
     }
 
     pub fn register_server(&self, poll: &Poll) -> io::Result<()> {
@@ -93,7 +140,7 @@
 
     fn deregister_client(&mut self, poll: &Poll, id: ClientId) {
         let mut client_exists = false;
-        if let Some(ref client) = self.clients.get_mut(id) {
+        if let Some(ref client) = self.clients.get(id) {
             poll.deregister(&client.socket)
                 .ok().expect("could not deregister socket");
             info!("client {} ({}) removed", client.id, client.peer_addr);
@@ -116,6 +163,29 @@
         entry.insert(client);
     }
 
+    fn flush_server_messages(&mut self) {
+        debug!("{} pending server messages", self.server.output.len());
+        for PendingMessage(destination, msg) in self.server.output.drain(..) {
+            match destination {
+                Destination::ToSelf(id)  => {
+                    if let Some(ref mut client) = self.clients.get_mut(id) {
+                        client.send_msg(msg);
+                        self.pending.push_back((id, NetworkClientState::NeedsWrite));
+                    }
+                }
+                Destination::ToOthers(id) => {
+                    let msg_string = msg.to_raw_protocol();
+                    for (client_id, client) in self.clients.iter_mut() {
+                        if client_id != id {
+                            client.send_string(&msg_string);
+                            self.pending.push_back((client_id, NetworkClientState::NeedsWrite));
+                        }
+                    }
+                }
+            }
+        }
+    }
+
     pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> {
         let (client_socket, addr) = self.listener.accept()?;
         info!("Connected: {}", addr);
@@ -127,56 +197,48 @@
         Ok(())
     }
 
-    fn flush_server_messages(&mut self) {
-        for PendingMessage(destination, msg) in self.server.output.drain(..) {
-            match destination {
-                Destination::ToSelf(id)  => {
-                    if let Some(ref mut client) = self.clients.get_mut(id) {
-                        client.send_msg(msg)
-                    }
-                }
-                Destination::ToOthers(id) => {
-                    let msg_string = msg.to_raw_protocol();
-                    for item in self.clients.iter_mut() {
-                        if item.0 != id {
-                            item.1.send_string(&msg_string)
-                        }
-                    }
-                }
-            }
-        }
+    fn operation_failed(&mut self, poll: &Poll, client_id: ClientId, error: Error, msg: &str) -> io::Result<()> {
+        let addr = if let Some(ref mut client) = self.clients.get_mut(client_id) {
+            client.peer_addr
+        } else {
+            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0)
+        };
+        debug!("{}({}): {}", msg, addr, error);
+        self.client_error(poll, client_id)
     }
 
     pub fn client_readable(&mut self, poll: &Poll,
                            client_id: ClientId) -> io::Result<()> {
-        let mut client_lost = false;
-        let messages;
-        if let Some(ref mut client) = self.clients.get_mut(client_id) {
-            messages = match client.read_messages() {
-                Ok(messages) => Some(messages),
-                Err(ref error) if error.kind() == ErrorKind::WouldBlock => None,
-                Err(error) => return Err(error)
+        let messages =
+            if let Some(ref mut client) = self.clients.get_mut(client_id) {
+                client.read_messages()
+            } else {
+                warn!("invalid readable client: {}", client_id);
+                Ok((Vec::new(), NetworkClientState::Idle))
             };
-            if client.closed {
-                client_lost = true;
+
+        match messages {
+            Ok((messages, state)) => {
+                for message in messages {
+                    self.server.handle_msg(client_id, message);
+                }
+                match state {
+                    NetworkClientState::NeedsRead =>
+                        self.pending.push_back((client_id, state)),
+                    NetworkClientState::Closed =>
+                        self.client_error(&poll, client_id)?,
+                    _ => {}
+                };
             }
-        } else {
-            warn!("invalid readable client: {}", client_id);
-            messages = None;
-        };
-
-        if client_lost {
-            self.client_error(&poll, client_id)?;
-        } else if let Some(msg) = messages {
-            for message in msg {
-                self.server.handle_msg(client_id, message);
-            }
-            self.flush_server_messages();
+            Err(e) => self.operation_failed(
+                poll, client_id, e,
+                "Error while reading from client socket")?
         }
 
+        self.flush_server_messages();
+
         if !self.server.removed_clients.is_empty() {
-            let ids = self.server.removed_clients.to_vec();
-            self.server.removed_clients.clear();
+            let ids: Vec<_> = self.server.removed_clients.drain(..).collect();
             for client_id in ids {
                 self.deregister_client(poll, client_id);
             }
@@ -187,14 +249,22 @@
 
     pub fn client_writable(&mut self, poll: &Poll,
                            client_id: ClientId) -> io::Result<()> {
-        if let Some(ref mut client) = self.clients.get_mut(client_id) {
-            match client.write_messages() {
-                Ok(_) => (),
-                Err(ref error) if error.kind() == ErrorKind::WouldBlock => (),
-                Err(error) => return Err(error)
-            }
-        } else {
-            warn!("invalid writable client: {}", client_id);
+        let result =
+            if let Some(ref mut client) = self.clients.get_mut(client_id) {
+                client.flush()
+            } else {
+                warn!("invalid writable client: {}", client_id);
+                Ok(((), NetworkClientState::Idle))
+            };
+
+        match result {
+            Ok(((), state)) if state == NetworkClientState::NeedsWrite =>
+                self.pending.push_back((client_id, state)),
+            Ok(_) =>
+                {}
+            Err(e) => self.operation_failed(
+                poll, client_id, e,
+                "Error while writing to client socket")?
         }
 
         Ok(())
@@ -207,5 +277,21 @@
 
         Ok(())
     }
+
+    pub fn has_pending_operations(&self) -> bool {
+        !self.pending.is_empty()
+    }
+
+    pub fn on_idle(&mut self, poll: &Poll) -> io::Result<()> {
+        while let Some((id, state)) = self.pending.pop_front() {
+            match state {
+                NetworkClientState::NeedsRead =>
+                    self.client_readable(poll, id)?,
+                NetworkClientState::NeedsWrite =>
+                    self.client_writable(poll, id)?,
+                _ => {}
+            }
+        }
+        Ok(())
+    }
 }
-