rust/hedgewars-server/src/protocol.rs
changeset 15831 7d0f747afcb8
parent 15804 747278149393
child 15832 a4d505a32879
equal deleted inserted replaced
15830:ea459da15b30 15831:7d0f747afcb8
       
     1 use bytes::{Buf, BufMut, BytesMut};
       
     2 use log::*;
       
     3 use std::{io, io::ErrorKind, marker::Unpin};
       
     4 use tokio::io::AsyncReadExt;
       
     5 
     1 use hedgewars_network_protocol::{
     6 use hedgewars_network_protocol::{
     2     messages::HwProtocolMessage,
     7     messages::HwProtocolMessage,
     3     parser::{malformed_message, message},
     8     parser::{malformed_message, message},
     4 };
     9 };
     5 use log::*;
       
     6 use netbuf;
       
     7 use std::io::{Read, Result};
       
     8 
    10 
     9 pub struct ProtocolDecoder {
    11 pub struct ProtocolDecoder {
    10     buf: netbuf::Buf,
    12     buffer: BytesMut,
    11     is_recovering: bool,
    13     is_recovering: bool,
    12 }
    14 }
    13 
    15 
    14 impl ProtocolDecoder {
    16 impl ProtocolDecoder {
    15     pub fn new() -> ProtocolDecoder {
    17     pub fn new() -> ProtocolDecoder {
    16         ProtocolDecoder {
    18         ProtocolDecoder {
    17             buf: netbuf::Buf::new(),
    19             buffer: BytesMut::with_capacity(1024),
    18             is_recovering: false,
    20             is_recovering: false,
    19         }
    21         }
    20     }
    22     }
    21 
    23 
    22     fn recover(&mut self) -> bool {
    24     fn recover(&mut self) -> bool {
    23         self.is_recovering = match malformed_message(&self.buf[..]) {
    25         self.is_recovering = match malformed_message(&self.buffer[..]) {
    24             Ok((tail, ())) => {
    26             Ok((tail, ())) => {
    25                 let length = tail.len();
    27                 let remaining = tail.len();
    26                 self.buf.consume(self.buf.len() - length);
    28                 self.buffer.advance(self.buffer.len() - remaining);
    27                 false
    29                 false
    28             }
    30             }
    29             _ => {
    31             _ => {
    30                 self.buf.consume(self.buf.len());
    32                 self.buffer.clear();
    31                 true
    33                 true
    32             }
    34             }
    33         };
    35         };
    34         !self.is_recovering
    36         !self.is_recovering
    35     }
    37     }
    36 
    38 
    37     pub fn read_from<R: Read>(&mut self, stream: &mut R) -> Result<usize> {
    39     fn extract_message(&mut self) -> Option<HwProtocolMessage> {
    38         let count = self.buf.read_from(stream)?;
    40         if !self.is_recovering || self.recover() {
    39         if count > 0 && self.is_recovering {
    41             match message(&self.buffer[..]) {
    40             self.recover();
    42                 Ok((tail, message)) => {
    41         }
    43                     let remaining = tail.len();
    42         Ok(count)
    44                     self.buffer.advance(self.buffer.len() - remaining);
    43     }
    45                     return Some(message);
    44 
    46                 }
    45     pub fn extract_messages(&mut self) -> Vec<HwProtocolMessage> {
    47                 Err(nom::Err::Incomplete(_)) => {}
    46         let mut messages = vec![];
    48                 Err(nom::Err::Failure(e) | nom::Err::Error(e)) => {
    47         if !self.is_recovering {
    49                     debug!("Invalid message: {:?}", e);
    48             while !self.buf.is_empty() {
    50                     self.recover();
    49                 match message(&self.buf[..]) {
       
    50                     Ok((tail, message)) => {
       
    51                         messages.push(message);
       
    52                         let length = tail.len();
       
    53                         self.buf.consume(self.buf.len() - length);
       
    54                     }
       
    55                     Err(nom::Err::Incomplete(_)) => break,
       
    56                     Err(nom::Err::Failure(e) | nom::Err::Error(e)) => {
       
    57                         debug!("Invalid message: {:?}", e);
       
    58                         if !self.recover() || self.buf.is_empty() {
       
    59                             break;
       
    60                         }
       
    61                     }
       
    62                 }
    51                 }
    63             }
    52             }
    64         }
    53         }
    65         messages
    54         None
       
    55     }
       
    56 
       
    57     pub async fn read_from<R: AsyncReadExt + Unpin>(
       
    58         &mut self,
       
    59         stream: &mut R,
       
    60     ) -> Option<HwProtocolMessage> {
       
    61         loop {
       
    62             if !self.buffer.has_remaining() {
       
    63                 let count = stream.read_buf(&mut self.buffer).await.ok()?;
       
    64                 if count == 0 {
       
    65                     return None;
       
    66                 }
       
    67             }
       
    68             while !self.buffer.is_empty() {
       
    69                 if let Some(result) = self.extract_message() {
       
    70                     return Some(result);
       
    71                 }
       
    72             }
       
    73         }
    66     }
    74     }
    67 }
    75 }