rust/hedgewars-server/src/protocol.rs
changeset 15831 7d0f747afcb8
parent 15804 747278149393
child 15832 a4d505a32879
--- a/rust/hedgewars-server/src/protocol.rs	Mon Jan 31 18:24:49 2022 +0300
+++ b/rust/hedgewars-server/src/protocol.rs	Tue Feb 01 02:23:35 2022 +0300
@@ -1,67 +1,75 @@
+use bytes::{Buf, BufMut, BytesMut};
+use log::*;
+use std::{io, io::ErrorKind, marker::Unpin};
+use tokio::io::AsyncReadExt;
+
 use hedgewars_network_protocol::{
     messages::HwProtocolMessage,
     parser::{malformed_message, message},
 };
-use log::*;
-use netbuf;
-use std::io::{Read, Result};
 
 pub struct ProtocolDecoder {
-    buf: netbuf::Buf,
+    buffer: BytesMut,
     is_recovering: bool,
 }
 
 impl ProtocolDecoder {
     pub fn new() -> ProtocolDecoder {
         ProtocolDecoder {
-            buf: netbuf::Buf::new(),
+            buffer: BytesMut::with_capacity(1024),
             is_recovering: false,
         }
     }
 
     fn recover(&mut self) -> bool {
-        self.is_recovering = match malformed_message(&self.buf[..]) {
+        self.is_recovering = match malformed_message(&self.buffer[..]) {
             Ok((tail, ())) => {
-                let length = tail.len();
-                self.buf.consume(self.buf.len() - length);
+                let remaining = tail.len();
+                self.buffer.advance(self.buffer.len() - remaining);
                 false
             }
             _ => {
-                self.buf.consume(self.buf.len());
+                self.buffer.clear();
                 true
             }
         };
         !self.is_recovering
     }
 
-    pub fn read_from<R: Read>(&mut self, stream: &mut R) -> Result<usize> {
-        let count = self.buf.read_from(stream)?;
-        if count > 0 && self.is_recovering {
-            self.recover();
-        }
-        Ok(count)
-    }
-
-    pub fn extract_messages(&mut self) -> Vec<HwProtocolMessage> {
-        let mut messages = vec![];
-        if !self.is_recovering {
-            while !self.buf.is_empty() {
-                match message(&self.buf[..]) {
-                    Ok((tail, message)) => {
-                        messages.push(message);
-                        let length = tail.len();
-                        self.buf.consume(self.buf.len() - length);
-                    }
-                    Err(nom::Err::Incomplete(_)) => break,
-                    Err(nom::Err::Failure(e) | nom::Err::Error(e)) => {
-                        debug!("Invalid message: {:?}", e);
-                        if !self.recover() || self.buf.is_empty() {
-                            break;
-                        }
-                    }
+    fn extract_message(&mut self) -> Option<HwProtocolMessage> {
+        if !self.is_recovering || self.recover() {
+            match message(&self.buffer[..]) {
+                Ok((tail, message)) => {
+                    let remaining = tail.len();
+                    self.buffer.advance(self.buffer.len() - remaining);
+                    return Some(message);
+                }
+                Err(nom::Err::Incomplete(_)) => {}
+                Err(nom::Err::Failure(e) | nom::Err::Error(e)) => {
+                    debug!("Invalid message: {:?}", e);
+                    self.recover();
                 }
             }
         }
-        messages
+        None
+    }
+
+    pub async fn read_from<R: AsyncReadExt + Unpin>(
+        &mut self,
+        stream: &mut R,
+    ) -> Option<HwProtocolMessage> {
+        loop {
+            if !self.buffer.has_remaining() {
+                let count = stream.read_buf(&mut self.buffer).await.ok()?;
+                if count == 0 {
+                    return None;
+                }
+            }
+            while !self.buffer.is_empty() {
+                if let Some(result) = self.extract_message() {
+                    return Some(result);
+                }
+            }
+        }
     }
 }