|
1 use bytes::{Buf, BufMut, BytesMut}; |
|
2 use log::*; |
|
3 use std::{io, io::ErrorKind, marker::Unpin}; |
|
4 use tokio::io::AsyncReadExt; |
|
5 |
1 use hedgewars_network_protocol::{ |
6 use hedgewars_network_protocol::{ |
2 messages::HwProtocolMessage, |
7 messages::HwProtocolMessage, |
3 parser::{malformed_message, message}, |
8 parser::{malformed_message, message}, |
4 }; |
9 }; |
5 use log::*; |
|
6 use netbuf; |
|
7 use std::io::{Read, Result}; |
|
8 |
10 |
9 pub struct ProtocolDecoder { |
11 pub struct ProtocolDecoder { |
10 buf: netbuf::Buf, |
12 buffer: BytesMut, |
11 is_recovering: bool, |
13 is_recovering: bool, |
12 } |
14 } |
13 |
15 |
14 impl ProtocolDecoder { |
16 impl ProtocolDecoder { |
15 pub fn new() -> ProtocolDecoder { |
17 pub fn new() -> ProtocolDecoder { |
16 ProtocolDecoder { |
18 ProtocolDecoder { |
17 buf: netbuf::Buf::new(), |
19 buffer: BytesMut::with_capacity(1024), |
18 is_recovering: false, |
20 is_recovering: false, |
19 } |
21 } |
20 } |
22 } |
21 |
23 |
22 fn recover(&mut self) -> bool { |
24 fn recover(&mut self) -> bool { |
23 self.is_recovering = match malformed_message(&self.buf[..]) { |
25 self.is_recovering = match malformed_message(&self.buffer[..]) { |
24 Ok((tail, ())) => { |
26 Ok((tail, ())) => { |
25 let length = tail.len(); |
27 let remaining = tail.len(); |
26 self.buf.consume(self.buf.len() - length); |
28 self.buffer.advance(self.buffer.len() - remaining); |
27 false |
29 false |
28 } |
30 } |
29 _ => { |
31 _ => { |
30 self.buf.consume(self.buf.len()); |
32 self.buffer.clear(); |
31 true |
33 true |
32 } |
34 } |
33 }; |
35 }; |
34 !self.is_recovering |
36 !self.is_recovering |
35 } |
37 } |
36 |
38 |
37 pub fn read_from<R: Read>(&mut self, stream: &mut R) -> Result<usize> { |
39 fn extract_message(&mut self) -> Option<HwProtocolMessage> { |
38 let count = self.buf.read_from(stream)?; |
40 if !self.is_recovering || self.recover() { |
39 if count > 0 && self.is_recovering { |
41 match message(&self.buffer[..]) { |
40 self.recover(); |
42 Ok((tail, message)) => { |
41 } |
43 let remaining = tail.len(); |
42 Ok(count) |
44 self.buffer.advance(self.buffer.len() - remaining); |
43 } |
45 return Some(message); |
44 |
46 } |
45 pub fn extract_messages(&mut self) -> Vec<HwProtocolMessage> { |
47 Err(nom::Err::Incomplete(_)) => {} |
46 let mut messages = vec![]; |
48 Err(nom::Err::Failure(e) | nom::Err::Error(e)) => { |
47 if !self.is_recovering { |
49 debug!("Invalid message: {:?}", e); |
48 while !self.buf.is_empty() { |
50 self.recover(); |
49 match message(&self.buf[..]) { |
|
50 Ok((tail, message)) => { |
|
51 messages.push(message); |
|
52 let length = tail.len(); |
|
53 self.buf.consume(self.buf.len() - length); |
|
54 } |
|
55 Err(nom::Err::Incomplete(_)) => break, |
|
56 Err(nom::Err::Failure(e) | nom::Err::Error(e)) => { |
|
57 debug!("Invalid message: {:?}", e); |
|
58 if !self.recover() || self.buf.is_empty() { |
|
59 break; |
|
60 } |
|
61 } |
|
62 } |
51 } |
63 } |
52 } |
64 } |
53 } |
65 messages |
54 None |
|
55 } |
|
56 |
|
57 pub async fn read_from<R: AsyncReadExt + Unpin>( |
|
58 &mut self, |
|
59 stream: &mut R, |
|
60 ) -> Option<HwProtocolMessage> { |
|
61 loop { |
|
62 if !self.buffer.has_remaining() { |
|
63 let count = stream.read_buf(&mut self.buffer).await.ok()?; |
|
64 if count == 0 { |
|
65 return None; |
|
66 } |
|
67 } |
|
68 while !self.buffer.is_empty() { |
|
69 if let Some(result) = self.extract_message() { |
|
70 return Some(result); |
|
71 } |
|
72 } |
|
73 } |
66 } |
74 } |
67 } |
75 } |