rust/hedgewars-server/src/protocol.rs
changeset 14800 add191d825f4
parent 14780 09d46ab83361
child 14801 f5d43f007970
--- a/rust/hedgewars-server/src/protocol.rs	Fri Apr 12 19:26:44 2019 +0300
+++ b/rust/hedgewars-server/src/protocol.rs	Fri Apr 12 22:36:54 2019 +0300
@@ -1,3 +1,5 @@
+use crate::protocol::parser::message;
+use log::*;
 use netbuf;
 use nom::{Err, ErrorKind, IResult};
 use std::io::{Read, Result};
@@ -10,6 +12,7 @@
 pub struct ProtocolDecoder {
     buf: netbuf::Buf,
     consumed: usize,
+    is_recovering: bool,
 }
 
 impl ProtocolDecoder {
@@ -17,26 +20,55 @@
         ProtocolDecoder {
             buf: netbuf::Buf::new(),
             consumed: 0,
+            is_recovering: false,
         }
     }
 
+    fn recover(&mut self) -> bool {
+        self.is_recovering = match parser::malformed_message(&self.buf[..]) {
+            Ok((tail, ())) => {
+                self.buf.consume(self.buf.len() - tail.len());
+                false
+            }
+            _ => {
+                self.buf.consume(self.buf.len());
+                true
+            }
+        };
+        !self.is_recovering
+    }
+
     pub fn read_from<R: Read>(&mut self, stream: &mut R) -> Result<usize> {
-        self.buf.read_from(stream)
+        let count = self.buf.read_from(stream)?;
+        if count > 0 && self.is_recovering {
+            self.recover();
+        }
+        Ok(count)
     }
 
     pub fn extract_messages(&mut self) -> Vec<messages::HWProtocolMessage> {
-        let parse_result = parser::extract_messages(&self.buf[..]);
-        match parse_result {
-            Ok((tail, msgs)) => {
-                self.consumed = self.buf.len() - self.consumed - tail.len();
-                msgs
+        let mut messages = vec![];
+        let mut consumed = 0;
+        if !self.is_recovering {
+            loop {
+                match parser::message(&self.buf[consumed..]) {
+                    Ok((tail, message)) => {
+                        messages.push(message);
+                        consumed += self.buf.len() - tail.len();
+                    }
+                    Err(nom::Err::Incomplete(_)) => break,
+                    Err(nom::Err::Failure(e)) | Err(nom::Err::Error(e)) => {
+                        debug!("Invalid message: {:?}", e);
+                        self.buf.consume(consumed);
+                        consumed = 0;
+                        if !self.recover() || self.buf.is_empty() {
+                            break;
+                        }
+                    }
+                }
             }
-            _ => unreachable!(),
         }
-    }
-
-    pub fn sweep(&mut self) {
-        self.buf.consume(self.consumed);
-        self.consumed = 0;
+        self.buf.consume(consumed);
+        messages
     }
 }