28 handlers, |
26 handlers, |
29 handlers::{IoResult, IoTask, ServerState}, |
27 handlers::{IoResult, IoTask, ServerState}, |
30 protocol::ProtocolDecoder, |
28 protocol::ProtocolDecoder, |
31 utils, |
29 utils, |
32 }; |
30 }; |
33 use hedgewars_network_protocol::{messages::HwServerMessage::Redirect, messages::*}; |
31 use hedgewars_network_protocol::{ |
34 |
32 messages::HwServerMessage::Redirect, messages::*, parser::server_message, |
35 #[cfg(feature = "official-server")] |
33 }; |
36 use super::io::{IoThread, RequestId}; |
34 use tokio::io::AsyncWriteExt; |
37 |
35 |
38 #[cfg(feature = "tls-connections")] |
36 enum ClientUpdateData { |
39 use openssl::{ |
37 Message(HwProtocolMessage), |
40 error::ErrorStack, |
38 Error(String), |
41 ssl::{ |
39 } |
42 HandshakeError, MidHandshakeSslStream, Ssl, SslContext, SslContextBuilder, SslFiletype, |
40 |
43 SslMethod, SslOptions, SslStream, SslStreamBuilder, SslVerifyMode, |
41 struct ClientUpdate { |
44 }, |
42 client_id: ClientId, |
45 }; |
43 data: ClientUpdateData, |
46 |
44 } |
47 const MAX_BYTES_PER_READ: usize = 2048; |
45 |
48 const SEND_PING_TIMEOUT: Duration = Duration::from_secs(5); |
46 struct ClientUpdateSender { |
49 const DROP_CLIENT_TIMEOUT: Duration = Duration::from_secs(5); |
47 client_id: ClientId, |
50 const MAX_TIMEOUT: usize = DROP_CLIENT_TIMEOUT.as_secs() as usize; |
48 sender: Sender<ClientUpdate>, |
51 const PING_PROBES_COUNT: u8 = 2; |
49 } |
52 |
50 |
53 #[derive(Hash, Eq, PartialEq, Copy, Clone)] |
51 impl ClientUpdateSender { |
54 pub enum NetworkClientState { |
52 async fn send(&mut self, data: ClientUpdateData) -> bool { |
55 Idle, |
53 self.sender |
56 NeedsWrite, |
54 .send(ClientUpdate { |
57 NeedsRead, |
55 client_id: self.client_id, |
58 Closed, |
56 data, |
59 #[cfg(feature = "tls-connections")] |
57 }) |
60 Connected, |
58 .await |
61 } |
59 .is_ok() |
62 |
60 } |
63 type NetworkResult<T> = io::Result<(T, NetworkClientState)>; |
61 } |
64 |
62 |
65 pub enum ClientSocket { |
63 struct NetworkClient { |
66 Plain(TcpStream), |
|
67 #[cfg(feature = "tls-connections")] |
|
68 SslHandshake(Option<MidHandshakeSslStream<TcpStream>>), |
|
69 #[cfg(feature = "tls-connections")] |
|
70 SslStream(SslStream<TcpStream>), |
|
71 } |
|
72 |
|
73 impl ClientSocket { |
|
74 fn inner_mut(&mut self) -> &mut TcpStream { |
|
75 match self { |
|
76 ClientSocket::Plain(stream) => stream, |
|
77 #[cfg(feature = "tls-connections")] |
|
78 ClientSocket::SslHandshake(Some(builder)) => builder.get_mut(), |
|
79 #[cfg(feature = "tls-connections")] |
|
80 ClientSocket::SslHandshake(None) => unreachable!(), |
|
81 #[cfg(feature = "tls-connections")] |
|
82 ClientSocket::SslStream(ssl_stream) => ssl_stream.get_mut(), |
|
83 } |
|
84 } |
|
85 } |
|
86 |
|
87 pub struct NetworkClient { |
|
88 id: ClientId, |
64 id: ClientId, |
89 socket: ClientSocket, |
65 socket: TcpStream, |
|
66 receiver: Receiver<Bytes>, |
90 peer_addr: SocketAddr, |
67 peer_addr: SocketAddr, |
91 decoder: ProtocolDecoder, |
68 decoder: ProtocolDecoder, |
92 buf_out: netbuf::Buf, |
|
93 pending_close: bool, |
|
94 timeout: Timeout, |
|
95 last_rx_time: Instant, |
|
96 } |
69 } |
97 |
70 |
98 impl NetworkClient { |
71 impl NetworkClient { |
99 pub fn new( |
72 fn new( |
100 id: ClientId, |
73 id: ClientId, |
101 socket: ClientSocket, |
74 socket: TcpStream, |
102 peer_addr: SocketAddr, |
75 peer_addr: SocketAddr, |
103 timeout: Timeout, |
76 receiver: Receiver<Bytes>, |
104 ) -> NetworkClient { |
77 ) -> Self { |
105 NetworkClient { |
78 Self { |
106 id, |
79 id, |
107 socket, |
80 socket, |
108 peer_addr, |
81 peer_addr, |
|
82 receiver, |
109 decoder: ProtocolDecoder::new(), |
83 decoder: ProtocolDecoder::new(), |
110 buf_out: netbuf::Buf::new(), |
84 } |
111 pending_close: false, |
85 } |
112 timeout, |
86 |
113 last_rx_time: Instant::now(), |
87 async fn read(&mut self) -> Option<HwProtocolMessage> { |
114 } |
88 self.decoder.read_from(&mut self.socket).await |
115 } |
89 } |
116 |
90 |
117 #[cfg(feature = "tls-connections")] |
91 async fn write(&mut self, mut data: Bytes) -> bool { |
118 fn handshake_impl( |
92 !data.has_remaining() || matches!(self.socket.write_buf(&mut data).await, Ok(n) if n > 0) |
119 &mut self, |
93 } |
120 handshake: MidHandshakeSslStream<TcpStream>, |
94 |
121 ) -> io::Result<NetworkClientState> { |
95 async fn run(mut self, sender: Sender<ClientUpdate>) { |
122 match handshake.handshake() { |
96 use ClientUpdateData::*; |
123 Ok(stream) => { |
97 let mut sender = ClientUpdateSender { |
124 self.socket = ClientSocket::SslStream(stream); |
98 client_id: self.id, |
125 debug!( |
99 sender, |
126 "TLS handshake with {} ({}) completed", |
|
127 self.id, self.peer_addr |
|
128 ); |
|
129 Ok(NetworkClientState::Connected) |
|
130 } |
|
131 Err(HandshakeError::WouldBlock(new_handshake)) => { |
|
132 self.socket = ClientSocket::SslHandshake(Some(new_handshake)); |
|
133 Ok(NetworkClientState::Idle) |
|
134 } |
|
135 Err(HandshakeError::Failure(new_handshake)) => { |
|
136 self.socket = ClientSocket::SslHandshake(Some(new_handshake)); |
|
137 debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr); |
|
138 Err(Error::new(ErrorKind::Other, "Connection failure")) |
|
139 } |
|
140 Err(HandshakeError::SetupFailure(_)) => unreachable!(), |
|
141 } |
|
142 } |
|
143 |
|
144 fn read_impl<R: Read>( |
|
145 decoder: &mut ProtocolDecoder, |
|
146 source: &mut R, |
|
147 id: ClientId, |
|
148 addr: &SocketAddr, |
|
149 ) -> NetworkResult<Vec<HwProtocolMessage>> { |
|
150 let mut bytes_read = 0; |
|
151 let result = loop { |
|
152 match decoder.read_from(source) { |
|
153 Ok(bytes) => { |
|
154 debug!("Client {}: read {} bytes", id, bytes); |
|
155 bytes_read += bytes; |
|
156 if bytes == 0 { |
|
157 let result = if bytes_read == 0 { |
|
158 info!("EOF for client {} ({})", id, addr); |
|
159 (Vec::new(), NetworkClientState::Closed) |
|
160 } else { |
|
161 (decoder.extract_messages(), NetworkClientState::NeedsRead) |
|
162 }; |
|
163 break Ok(result); |
|
164 } else if bytes_read >= MAX_BYTES_PER_READ { |
|
165 break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead)); |
|
166 } |
|
167 } |
|
168 Err(ref error) if error.kind() == ErrorKind::WouldBlock => { |
|
169 let messages = if bytes_read == 0 { |
|
170 Vec::new() |
|
171 } else { |
|
172 decoder.extract_messages() |
|
173 }; |
|
174 break Ok((messages, NetworkClientState::Idle)); |
|
175 } |
|
176 Err(error) => break Err(error), |
|
177 } |
|
178 }; |
100 }; |
179 result |
101 |
180 } |
102 loop { |
181 |
103 tokio::select! { |
182 pub fn read(&mut self) -> NetworkResult<Vec<HwProtocolMessage>> { |
104 server_message = self.receiver.recv() => { |
183 let result = match self.socket { |
105 match server_message { |
184 ClientSocket::Plain(ref mut stream) => { |
106 Some(message) => if !self.write(message).await { |
185 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) |
107 sender.send(Error("Connection reset by peer".to_string())).await; |
186 } |
108 break; |
187 #[cfg(feature = "tls-connections")] |
109 } |
188 ClientSocket::SslHandshake(ref mut handshake_opt) => { |
110 None => { |
189 let handshake = std::mem::replace(handshake_opt, None).unwrap(); |
111 break; |
190 Ok((Vec::new(), self.handshake_impl(handshake)?)) |
112 } |
191 } |
113 } |
192 #[cfg(feature = "tls-connections")] |
114 } |
193 ClientSocket::SslStream(ref mut stream) => { |
115 client_message = self.decoder.read_from(&mut self.socket) => { |
194 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) |
116 match client_message { |
195 } |
117 Some(message) => { |
196 }; |
118 if !sender.send(Message(message)).await { |
197 |
119 break; |
198 if let Ok(_) = result { |
120 } |
199 self.last_rx_time = Instant::now(); |
121 } |
200 } |
122 None => { |
201 |
123 sender.send(Error("Connection reset by peer".to_string())).await; |
202 result |
124 break; |
203 } |
125 } |
204 |
126 } |
205 fn write_impl<W: Write>( |
127 } |
206 buf_out: &mut netbuf::Buf, |
128 } |
207 destination: &mut W, |
129 } |
208 close_on_empty: bool, |
130 } |
209 ) -> NetworkResult<()> { |
131 } |
210 let result = loop { |
|
211 match buf_out.write_to(destination) { |
|
212 Ok(bytes) if buf_out.is_empty() || bytes == 0 => { |
|
213 let status = if buf_out.is_empty() && close_on_empty { |
|
214 NetworkClientState::Closed |
|
215 } else { |
|
216 NetworkClientState::Idle |
|
217 }; |
|
218 break Ok(((), status)); |
|
219 } |
|
220 Ok(_) => (), |
|
221 Err(ref error) |
|
222 if error.kind() == ErrorKind::Interrupted |
|
223 || error.kind() == ErrorKind::WouldBlock => |
|
224 { |
|
225 break Ok(((), NetworkClientState::NeedsWrite)); |
|
226 } |
|
227 Err(error) => break Err(error), |
|
228 } |
|
229 }; |
|
230 result |
|
231 } |
|
232 |
|
233 pub fn write(&mut self) -> NetworkResult<()> { |
|
234 let result = match self.socket { |
|
235 ClientSocket::Plain(ref mut stream) => { |
|
236 NetworkClient::write_impl(&mut self.buf_out, stream, self.pending_close) |
|
237 } |
|
238 #[cfg(feature = "tls-connections")] |
|
239 ClientSocket::SslHandshake(ref mut handshake_opt) => { |
|
240 let handshake = std::mem::replace(handshake_opt, None).unwrap(); |
|
241 Ok(((), self.handshake_impl(handshake)?)) |
|
242 } |
|
243 #[cfg(feature = "tls-connections")] |
|
244 ClientSocket::SslStream(ref mut stream) => { |
|
245 NetworkClient::write_impl(&mut self.buf_out, stream, self.pending_close) |
|
246 } |
|
247 }; |
|
248 |
|
249 self.socket.inner_mut().flush()?; |
|
250 result |
|
251 } |
|
252 |
|
253 pub fn send_raw_msg(&mut self, msg: &[u8]) { |
|
254 self.buf_out.write_all(msg).unwrap(); |
|
255 } |
|
256 |
|
257 pub fn send_string(&mut self, msg: &str) { |
|
258 self.send_raw_msg(&msg.as_bytes()); |
|
259 } |
|
260 |
|
261 pub fn replace_timeout(&mut self, timeout: Timeout) -> Timeout { |
|
262 replace(&mut self.timeout, timeout) |
|
263 } |
|
264 |
|
265 pub fn has_pending_sends(&self) -> bool { |
|
266 !self.buf_out.is_empty() |
|
267 } |
|
268 } |
|
269 |
|
270 #[cfg(feature = "tls-connections")] |
|
271 struct ServerSsl { |
|
272 listener: TcpListener, |
|
273 context: SslContext, |
|
274 } |
|
275 |
|
276 #[cfg(feature = "official-server")] |
|
277 pub struct IoLayer { |
|
278 next_request_id: RequestId, |
|
279 request_queue: Vec<(RequestId, ClientId)>, |
|
280 io_thread: IoThread, |
|
281 } |
|
282 |
|
283 #[cfg(feature = "official-server")] |
|
284 impl IoLayer { |
|
285 fn new(waker: Waker) -> Self { |
|
286 Self { |
|
287 next_request_id: 0, |
|
288 request_queue: vec![], |
|
289 io_thread: IoThread::new(waker), |
|
290 } |
|
291 } |
|
292 |
|
293 fn send(&mut self, client_id: ClientId, task: IoTask) { |
|
294 let request_id = self.next_request_id; |
|
295 self.next_request_id += 1; |
|
296 self.request_queue.push((request_id, client_id)); |
|
297 self.io_thread.send(request_id, task); |
|
298 } |
|
299 |
|
300 fn try_recv(&mut self) -> Option<(ClientId, IoResult)> { |
|
301 let (request_id, result) = self.io_thread.try_recv()?; |
|
302 if let Some(index) = self |
|
303 .request_queue |
|
304 .iter() |
|
305 .position(|(id, _)| *id == request_id) |
|
306 { |
|
307 let (_, client_id) = self.request_queue.swap_remove(index); |
|
308 Some((client_id, result)) |
|
309 } else { |
|
310 None |
|
311 } |
|
312 } |
|
313 |
|
314 fn cancel(&mut self, client_id: ClientId) { |
|
315 let mut index = 0; |
|
316 while index < self.request_queue.len() { |
|
317 if self.request_queue[index].1 == client_id { |
|
318 self.request_queue.swap_remove(index); |
|
319 } else { |
|
320 index += 1; |
|
321 } |
|
322 } |
|
323 } |
|
324 } |
|
325 |
|
326 enum TimeoutEvent { |
|
327 SendPing { probes_count: u8 }, |
|
328 DropClient, |
|
329 } |
|
330 |
|
331 struct TimerData(TimeoutEvent, ClientId); |
|
332 type NetworkTimeoutEvents = TimedEvents<TimerData, MAX_TIMEOUT>; |
|
333 |
132 |
334 pub struct NetworkLayer { |
133 pub struct NetworkLayer { |
335 listener: TcpListener, |
134 listener: TcpListener, |
336 server_state: ServerState, |
135 server_state: ServerState, |
337 clients: Slab<NetworkClient>, |
136 clients: Slab<Sender<Bytes>>, |
338 pending: HashSet<(ClientId, NetworkClientState)>, |
|
339 pending_cache: Vec<(ClientId, NetworkClientState)>, |
|
340 #[cfg(feature = "tls-connections")] |
|
341 ssl: ServerSsl, |
|
342 #[cfg(feature = "official-server")] |
|
343 io: IoLayer, |
|
344 timeout_events: NetworkTimeoutEvents, |
|
345 } |
|
346 |
|
347 fn register_read<S: Source>(poll: &Poll, source: &mut S, token: mio::Token) -> io::Result<()> { |
|
348 poll.registry().register(source, token, Interest::READABLE) |
|
349 } |
|
350 |
|
351 fn create_ping_timeout( |
|
352 timeout_events: &mut NetworkTimeoutEvents, |
|
353 probes_count: u8, |
|
354 client_id: ClientId, |
|
355 ) -> Timeout { |
|
356 timeout_events.set_timeout( |
|
357 NonZeroU32::new(SEND_PING_TIMEOUT.as_secs() as u32).unwrap(), |
|
358 TimerData(TimeoutEvent::SendPing { probes_count }, client_id), |
|
359 ) |
|
360 } |
|
361 |
|
362 fn create_drop_timeout(timeout_events: &mut NetworkTimeoutEvents, client_id: ClientId) -> Timeout { |
|
363 timeout_events.set_timeout( |
|
364 NonZeroU32::new(DROP_CLIENT_TIMEOUT.as_secs() as u32).unwrap(), |
|
365 TimerData(TimeoutEvent::DropClient, client_id), |
|
366 ) |
|
367 } |
137 } |
368 |
138 |
369 impl NetworkLayer { |
139 impl NetworkLayer { |
370 pub fn register(&mut self, poll: &Poll) -> io::Result<()> { |
140 pub async fn run(&mut self) { |
371 register_read(poll, &mut self.listener, utils::SERVER_TOKEN)?; |
141 let (update_tx, mut update_rx) = channel(128); |
372 #[cfg(feature = "tls-connections")] |
142 |
373 register_read(poll, &mut self.ssl.listener, utils::SECURE_SERVER_TOKEN)?; |
143 loop { |
374 |
144 tokio::select! { |
375 Ok(()) |
145 Ok((stream, addr)) = self.listener.accept() => { |
376 } |
146 if let Some(client) = self.create_client(stream, addr).await { |
377 |
147 tokio::spawn(client.run(update_tx.clone())); |
378 fn deregister_client(&mut self, poll: &Poll, id: ClientId, is_error: bool) { |
148 } |
379 if let Some(ref mut client) = self.clients.get_mut(id) { |
149 } |
380 poll.registry() |
150 client_message = update_rx.recv(), if !self.clients.is_empty() => { |
381 .deregister(client.socket.inner_mut()) |
151 use ClientUpdateData::*; |
382 .expect("could not deregister socket"); |
152 match client_message { |
383 if client.has_pending_sends() && !is_error { |
153 Some(ClientUpdate{ client_id, data: Message(message) } ) => { |
384 info!( |
154 self.handle_message(client_id, message).await; |
385 "client {} ({}) pending removal", |
155 } |
386 client.id, client.peer_addr |
156 Some(ClientUpdate{ client_id, .. } ) => { |
387 ); |
157 let mut response = handlers::Response::new(client_id); |
388 client.pending_close = true; |
158 handlers::handle_client_loss(&mut self.server_state, client_id, &mut response); |
389 poll.registry() |
159 self.handle_response(response).await; |
390 .register(client.socket.inner_mut(), Token(id), Interest::WRITABLE) |
160 } |
391 .unwrap_or_else(|_| { |
161 None => unreachable!() |
392 self.clients.remove(id); |
162 } |
393 }); |
163 } |
394 } else { |
164 } |
395 info!("client {} ({}) removed", client.id, client.peer_addr); |
165 } |
396 self.clients.remove(id); |
166 } |
397 } |
167 |
398 #[cfg(feature = "official-server")] |
168 async fn create_client( |
399 self.io.cancel(id); |
|
400 } |
|
401 } |
|
402 |
|
403 fn register_client( |
|
404 &mut self, |
169 &mut self, |
405 poll: &Poll, |
170 stream: TcpStream, |
406 mut client_socket: ClientSocket, |
|
407 addr: SocketAddr, |
171 addr: SocketAddr, |
408 ) -> io::Result<ClientId> { |
172 ) -> Option<NetworkClient> { |
409 let entry = self.clients.vacant_entry(); |
173 let entry = self.clients.vacant_entry(); |
410 let client_id = entry.key(); |
174 let client_id = entry.key(); |
411 |
175 let (tx, rx) = channel(16); |
412 poll.registry().register( |
176 entry.insert(tx); |
413 client_socket.inner_mut(), |
177 |
414 Token(client_id), |
178 let client = NetworkClient::new(client_id, stream, addr, rx); |
415 Interest::READABLE | Interest::WRITABLE, |
179 |
416 )?; |
|
417 |
|
418 let client = NetworkClient::new( |
|
419 client_id, |
|
420 client_socket, |
|
421 addr, |
|
422 create_ping_timeout(&mut self.timeout_events, PING_PROBES_COUNT - 1, client_id), |
|
423 ); |
|
424 info!("client {} ({}) added", client.id, client.peer_addr); |
180 info!("client {} ({}) added", client.id, client.peer_addr); |
425 entry.insert(client); |
181 |
426 |
|
427 Ok(client_id) |
|
428 } |
|
429 |
|
430 fn handle_response(&mut self, mut response: handlers::Response, poll: &Poll) { |
|
431 if response.is_empty() { |
|
432 return; |
|
433 } |
|
434 |
|
435 debug!("{} pending server messages", response.len()); |
|
436 let output = response.extract_messages(&mut self.server_state.server); |
|
437 for (clients, message) in output { |
|
438 debug!("Message {:?} to {:?}", message, clients); |
|
439 let msg_string = message.to_raw_protocol(); |
|
440 for client_id in clients { |
|
441 if let Some(client) = self.clients.get_mut(client_id) { |
|
442 client.send_string(&msg_string); |
|
443 self.pending |
|
444 .insert((client_id, NetworkClientState::NeedsWrite)); |
|
445 } |
|
446 } |
|
447 } |
|
448 |
|
449 for client_id in response.extract_removed_clients() { |
|
450 self.deregister_client(poll, client_id, false); |
|
451 } |
|
452 |
|
453 #[cfg(feature = "official-server")] |
|
454 { |
|
455 let client_id = response.client_id(); |
|
456 for task in response.extract_io_tasks() { |
|
457 self.io.send(client_id, task); |
|
458 } |
|
459 } |
|
460 } |
|
461 |
|
462 pub fn handle_timeout(&mut self, poll: &mut Poll) -> io::Result<()> { |
|
463 for TimerData(event, client_id) in self.timeout_events.poll(Instant::now()) { |
|
464 if let Some(client) = self.clients.get_mut(client_id) { |
|
465 if client.last_rx_time.elapsed() > SEND_PING_TIMEOUT { |
|
466 match event { |
|
467 TimeoutEvent::SendPing { probes_count } => { |
|
468 client.send_string(&HwServerMessage::Ping.to_raw_protocol()); |
|
469 client.write()?; |
|
470 let timeout = if probes_count != 0 { |
|
471 create_ping_timeout( |
|
472 &mut self.timeout_events, |
|
473 probes_count - 1, |
|
474 client_id, |
|
475 ) |
|
476 } else { |
|
477 create_drop_timeout(&mut self.timeout_events, client_id) |
|
478 }; |
|
479 client.replace_timeout(timeout); |
|
480 } |
|
481 TimeoutEvent::DropClient => { |
|
482 client.send_string( |
|
483 &HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol(), |
|
484 ); |
|
485 let _res = client.write(); |
|
486 |
|
487 self.operation_failed( |
|
488 poll, |
|
489 client_id, |
|
490 &ErrorKind::TimedOut.into(), |
|
491 "No ping response", |
|
492 )?; |
|
493 } |
|
494 } |
|
495 } else { |
|
496 client.replace_timeout(create_ping_timeout( |
|
497 &mut self.timeout_events, |
|
498 PING_PROBES_COUNT - 1, |
|
499 client_id, |
|
500 )); |
|
501 } |
|
502 } |
|
503 } |
|
504 Ok(()) |
|
505 } |
|
506 |
|
507 #[cfg(feature = "official-server")] |
|
508 pub fn handle_io_result(&mut self, poll: &Poll) -> io::Result<()> { |
|
509 while let Some((client_id, result)) = self.io.try_recv() { |
|
510 debug!("Handling io result {:?} for client {}", result, client_id); |
|
511 let mut response = handlers::Response::new(client_id); |
|
512 handlers::handle_io_result(&mut self.server_state, client_id, &mut response, result); |
|
513 self.handle_response(response, poll); |
|
514 } |
|
515 Ok(()) |
|
516 } |
|
517 |
|
518 fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> { |
|
519 Ok(ClientSocket::Plain(socket)) |
|
520 } |
|
521 |
|
522 #[cfg(feature = "tls-connections")] |
|
523 fn create_client_secure_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> { |
|
524 let ssl = Ssl::new(&self.ssl.context).unwrap(); |
|
525 let mut builder = SslStreamBuilder::new(ssl, socket); |
|
526 builder.set_accept_state(); |
|
527 match builder.handshake() { |
|
528 Ok(stream) => Ok(ClientSocket::SslStream(stream)), |
|
529 Err(HandshakeError::WouldBlock(stream)) => Ok(ClientSocket::SslHandshake(Some(stream))), |
|
530 Err(e) => { |
|
531 debug!("OpenSSL handshake failed: {}", e); |
|
532 Err(Error::new(ErrorKind::Other, "Connection failure")) |
|
533 } |
|
534 } |
|
535 } |
|
536 |
|
537 fn init_client(&mut self, poll: &Poll, client_id: ClientId) { |
|
538 let mut response = handlers::Response::new(client_id); |
182 let mut response = handlers::Response::new(client_id); |
539 |
183 |
540 if let ClientSocket::Plain(_) = self.clients[client_id].socket { |
184 let added = if let IpAddr::V4(addr) = client.peer_addr.ip() { |
541 #[cfg(feature = "tls-connections")] |
|
542 response.add(Redirect(self.ssl.listener.local_addr().unwrap().port()).send_self()) |
|
543 } |
|
544 |
|
545 if let IpAddr::V4(addr) = self.clients[client_id].peer_addr.ip() { |
|
546 handlers::handle_client_accept( |
185 handlers::handle_client_accept( |
547 &mut self.server_state, |
186 &mut self.server_state, |
548 client_id, |
187 client_id, |
549 &mut response, |
188 &mut response, |
550 addr.octets(), |
189 addr.octets(), |
551 addr.is_loopback(), |
190 addr.is_loopback(), |
552 ); |
191 ) |
553 self.handle_response(response, poll); |
|
554 } else { |
192 } else { |
555 todo!("implement something") |
193 todo!("implement something") |
556 } |
194 }; |
557 } |
195 |
558 |
196 self.handle_response(response).await; |
559 pub fn accept_client(&mut self, poll: &Poll, server_token: mio::Token) -> io::Result<()> { |
197 |
560 match server_token { |
198 if added { |
561 utils::SERVER_TOKEN => { |
199 Some(client) |
562 let (client_socket, addr) = self.listener.accept()?; |
|
563 info!("Connected(plaintext): {}", addr); |
|
564 let client_id = |
|
565 self.register_client(poll, self.create_client_socket(client_socket)?, addr)?; |
|
566 self.init_client(poll, client_id); |
|
567 } |
|
568 #[cfg(feature = "tls-connections")] |
|
569 utils::SECURE_SERVER_TOKEN => { |
|
570 let (client_socket, addr) = self.ssl.listener.accept()?; |
|
571 info!("Connected(TLS): {}", addr); |
|
572 self.register_client(poll, self.create_client_secure_socket(client_socket)?, addr)?; |
|
573 } |
|
574 _ => unreachable!(), |
|
575 } |
|
576 |
|
577 Ok(()) |
|
578 } |
|
579 |
|
580 fn operation_failed( |
|
581 &mut self, |
|
582 poll: &Poll, |
|
583 client_id: ClientId, |
|
584 error: &Error, |
|
585 msg: &str, |
|
586 ) -> io::Result<()> { |
|
587 let addr = if let Some(ref mut client) = self.clients.get_mut(client_id) { |
|
588 client.peer_addr |
|
589 } else { |
200 } else { |
590 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0) |
201 None |
591 }; |
202 } |
592 debug!("{}({}): {}", msg, addr, error); |
203 } |
593 self.client_error(poll, client_id) |
204 |
594 } |
205 async fn handle_message(&mut self, client_id: ClientId, message: HwProtocolMessage) { |
595 |
206 debug!("Handling message {:?} for client {}", message, client_id); |
596 pub fn client_readable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> { |
|
597 let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) { |
|
598 client.read() |
|
599 } else { |
|
600 warn!("invalid readable client: {}", client_id); |
|
601 Ok((Vec::new(), NetworkClientState::Idle)) |
|
602 }; |
|
603 |
|
604 let mut response = handlers::Response::new(client_id); |
207 let mut response = handlers::Response::new(client_id); |
605 |
208 handlers::handle(&mut self.server_state, client_id, &mut response, message); |
606 match messages { |
209 self.handle_response(response).await; |
607 Ok((messages, state)) => { |
210 } |
608 for message in messages { |
211 |
609 debug!("Handling message {:?} for client {}", message, client_id); |
212 async fn handle_response(&mut self, mut response: handlers::Response) { |
610 handlers::handle(&mut self.server_state, client_id, &mut response, message); |
213 if response.is_empty() { |
611 } |
214 return; |
612 match state { |
215 } |
613 NetworkClientState::NeedsRead => { |
216 |
614 self.pending.insert((client_id, state)); |
217 debug!("{} pending server messages", response.len()); |
615 } |
218 let output = response.extract_messages(&mut self.server_state.server); |
616 NetworkClientState::Closed => self.client_error(&poll, client_id)?, |
219 for (clients, message) in output { |
617 #[cfg(feature = "tls-connections")] |
220 debug!("Message {:?} to {:?}", message, clients); |
618 NetworkClientState::Connected => self.init_client(poll, client_id), |
221 Self::send_message(&mut self.clients, message, clients.iter().cloned()).await; |
619 _ => {} |
222 } |
620 }; |
223 |
621 } |
224 for client_id in response.extract_removed_clients() { |
622 Err(e) => self.operation_failed( |
225 if self.clients.contains(client_id) { |
623 poll, |
226 self.clients.remove(client_id); |
624 client_id, |
227 } |
625 &e, |
228 info!("Client {} removed", client_id); |
626 "Error while reading from client socket", |
229 } |
627 )?, |
230 } |
628 } |
231 |
629 |
232 async fn send_message<I>( |
630 self.handle_response(response, poll); |
233 clients: &mut Slab<Sender<Bytes>>, |
631 |
234 message: HwServerMessage, |
632 Ok(()) |
235 to_clients: I, |
633 } |
236 ) where |
634 |
237 I: Iterator<Item = ClientId>, |
635 pub fn client_writable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> { |
238 { |
636 let result = if let Some(ref mut client) = self.clients.get_mut(client_id) { |
239 let msg_string = message.to_raw_protocol(); |
637 client.write() |
240 let bytes = Bytes::copy_from_slice(msg_string.as_bytes()); |
638 } else { |
241 for client_id in to_clients { |
639 warn!("invalid writable client: {}", client_id); |
242 if let Some(client) = clients.get_mut(client_id) { |
640 Ok(((), NetworkClientState::Idle)) |
243 if !client.send(bytes.clone()).await.is_ok() { |
641 }; |
244 clients.remove(client_id); |
642 |
245 } |
643 match result { |
246 } |
644 Ok(((), state)) if state == NetworkClientState::NeedsWrite => { |
247 } |
645 self.pending.insert((client_id, state)); |
|
646 } |
|
647 Ok(((), state)) if state == NetworkClientState::Closed => { |
|
648 self.deregister_client(poll, client_id, false); |
|
649 } |
|
650 Ok(_) => (), |
|
651 Err(e) => { |
|
652 self.operation_failed(poll, client_id, &e, "Error while writing to client socket")? |
|
653 } |
|
654 } |
|
655 |
|
656 Ok(()) |
|
657 } |
|
658 |
|
659 pub fn client_error(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> { |
|
660 let pending_close = self.clients[client_id].pending_close; |
|
661 self.deregister_client(poll, client_id, true); |
|
662 |
|
663 if !pending_close { |
|
664 let mut response = handlers::Response::new(client_id); |
|
665 handlers::handle_client_loss(&mut self.server_state, client_id, &mut response); |
|
666 self.handle_response(response, poll); |
|
667 } |
|
668 |
|
669 Ok(()) |
|
670 } |
|
671 |
|
672 pub fn has_pending_operations(&self) -> bool { |
|
673 !self.pending.is_empty() || !self.timeout_events.is_empty() |
|
674 } |
|
675 |
|
676 pub fn on_idle(&mut self, poll: &Poll) -> io::Result<()> { |
|
677 if self.has_pending_operations() { |
|
678 let mut cache = replace(&mut self.pending_cache, Vec::new()); |
|
679 cache.extend(self.pending.drain()); |
|
680 for (id, state) in cache.drain(..) { |
|
681 match state { |
|
682 NetworkClientState::NeedsRead => self.client_readable(poll, id)?, |
|
683 NetworkClientState::NeedsWrite => self.client_writable(poll, id)?, |
|
684 _ => {} |
|
685 } |
|
686 } |
|
687 swap(&mut cache, &mut self.pending_cache); |
|
688 } |
|
689 Ok(()) |
|
690 } |
248 } |
691 } |
249 } |
692 |
250 |
693 pub struct NetworkLayerBuilder { |
251 pub struct NetworkLayerBuilder { |
694 listener: Option<TcpListener>, |
252 listener: Option<TcpListener>, |
695 secure_listener: Option<TcpListener>, |
|
696 clients_capacity: usize, |
253 clients_capacity: usize, |
697 rooms_capacity: usize, |
254 rooms_capacity: usize, |
698 } |
255 } |
699 |
256 |
700 impl Default for NetworkLayerBuilder { |
257 impl Default for NetworkLayerBuilder { |
701 fn default() -> Self { |
258 fn default() -> Self { |
702 Self { |
259 Self { |
703 clients_capacity: 1024, |
260 clients_capacity: 1024, |
704 rooms_capacity: 512, |
261 rooms_capacity: 512, |
705 listener: None, |
262 listener: None, |
706 secure_listener: None, |
|
707 } |
263 } |
708 } |
264 } |
709 } |
265 } |
710 |
266 |
711 impl NetworkLayerBuilder { |
267 impl NetworkLayerBuilder { |