|
1 extern crate slab; |
|
2 |
|
3 use std::{ |
|
4 io, io::{Error, ErrorKind, Read, Write}, |
|
5 net::{SocketAddr, IpAddr, Ipv4Addr}, |
|
6 collections::HashSet, |
|
7 mem::{swap, replace} |
|
8 }; |
|
9 |
|
10 use mio::{ |
|
11 net::{TcpStream, TcpListener}, |
|
12 Poll, PollOpt, Ready, Token |
|
13 }; |
|
14 use netbuf; |
|
15 use slab::Slab; |
|
16 use log::*; |
|
17 |
|
18 use crate::{ |
|
19 utils, |
|
20 protocol::{ProtocolDecoder, messages::*} |
|
21 }; |
|
22 use super::{ |
|
23 io::FileServerIO, |
|
24 core::{HWServer}, |
|
25 coretypes::ClientId |
|
26 }; |
|
27 #[cfg(feature = "tls-connections")] |
|
28 use openssl::{ |
|
29 ssl::{ |
|
30 SslMethod, SslContext, Ssl, SslContextBuilder, |
|
31 SslVerifyMode, SslFiletype, SslOptions, |
|
32 SslStreamBuilder, HandshakeError, MidHandshakeSslStream, SslStream |
|
33 }, |
|
34 error::ErrorStack |
|
35 }; |
|
36 |
|
/// Upper bound on bytes consumed from one client in a single `read_impl`
/// call; once reached, the remaining input is deferred via a `NeedsRead`
/// pending operation so one busy client cannot monopolise the loop.
const MAX_BYTES_PER_READ: usize = 2 * 1024;
|
38 |
|
/// Follow-up work a client connection needs after an I/O step.
///
/// Values are stored together with the client id in the network layer's
/// pending-operation set, hence the `Hash`/`Eq` derives. `Debug` allows the
/// state to be logged and compared in assertions.
#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)]
pub enum NetworkClientState {
    /// Nothing to do until the next readiness event.
    Idle,
    /// Buffered output still has to be written to the socket.
    NeedsWrite,
    /// Reading stopped before the socket was drained (byte budget reached
    /// or EOF seen after data); another read should follow.
    NeedsRead,
    /// The peer closed the connection.
    Closed,
}
|
46 |
|
47 type NetworkResult<T> = io::Result<(T, NetworkClientState)>; |
|
48 |
|
49 #[cfg(not(feature = "tls-connections"))] |
|
50 pub enum ClientSocket { |
|
51 Plain(TcpStream) |
|
52 } |
|
53 |
|
/// Transport of a client connection when TLS support is compiled in.
#[cfg(feature = "tls-connections")]
pub enum ClientSocket {
    /// TLS handshake still in progress. The `Option` lets the handshake be
    /// moved out while it is advanced; it is stored back after each step.
    SslHandshake(Option<MidHandshakeSslStream<TcpStream>>),
    /// Handshake finished; the encrypted stream carries application data.
    SslStream(SslStream<TcpStream>),
}
|
59 |
|
60 impl ClientSocket { |
|
61 fn inner(&self) -> &TcpStream { |
|
62 #[cfg(not(feature = "tls-connections"))] |
|
63 match self { |
|
64 ClientSocket::Plain(stream) => stream, |
|
65 } |
|
66 |
|
67 #[cfg(feature = "tls-connections")] |
|
68 match self { |
|
69 ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(), |
|
70 ClientSocket::SslHandshake(None) => unreachable!(), |
|
71 ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref() |
|
72 } |
|
73 } |
|
74 } |
|
75 |
|
76 pub struct NetworkClient { |
|
77 id: ClientId, |
|
78 socket: ClientSocket, |
|
79 peer_addr: SocketAddr, |
|
80 decoder: ProtocolDecoder, |
|
81 buf_out: netbuf::Buf |
|
82 } |
|
83 |
|
84 impl NetworkClient { |
|
85 pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient { |
|
86 NetworkClient { |
|
87 id, socket, peer_addr, |
|
88 decoder: ProtocolDecoder::new(), |
|
89 buf_out: netbuf::Buf::new() |
|
90 } |
|
91 } |
|
92 |
|
93 #[cfg(feature = "tls-connections")] |
|
94 fn handshake_impl(&mut self, handshake: MidHandshakeSslStream<TcpStream>) -> io::Result<NetworkClientState> { |
|
95 match handshake.handshake() { |
|
96 Ok(stream) => { |
|
97 self.socket = ClientSocket::SslStream(stream); |
|
98 debug!("TLS handshake with {} ({}) completed", self.id, self.peer_addr); |
|
99 Ok(NetworkClientState::Idle) |
|
100 } |
|
101 Err(HandshakeError::WouldBlock(new_handshake)) => { |
|
102 self.socket = ClientSocket::SslHandshake(Some(new_handshake)); |
|
103 Ok(NetworkClientState::Idle) |
|
104 } |
|
105 Err(HandshakeError::Failure(new_handshake)) => { |
|
106 self.socket = ClientSocket::SslHandshake(Some(new_handshake)); |
|
107 debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr); |
|
108 Err(Error::new(ErrorKind::Other, "Connection failure")) |
|
109 } |
|
110 Err(HandshakeError::SetupFailure(_)) => unreachable!() |
|
111 } |
|
112 } |
|
113 |
|
114 fn read_impl<R: Read>(decoder: &mut ProtocolDecoder, source: &mut R, |
|
115 id: ClientId, addr: &SocketAddr) -> NetworkResult<Vec<HWProtocolMessage>> { |
|
116 let mut bytes_read = 0; |
|
117 let result = loop { |
|
118 match decoder.read_from(source) { |
|
119 Ok(bytes) => { |
|
120 debug!("Client {}: read {} bytes", id, bytes); |
|
121 bytes_read += bytes; |
|
122 if bytes == 0 { |
|
123 let result = if bytes_read == 0 { |
|
124 info!("EOF for client {} ({})", id, addr); |
|
125 (Vec::new(), NetworkClientState::Closed) |
|
126 } else { |
|
127 (decoder.extract_messages(), NetworkClientState::NeedsRead) |
|
128 }; |
|
129 break Ok(result); |
|
130 } |
|
131 else if bytes_read >= MAX_BYTES_PER_READ { |
|
132 break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead)) |
|
133 } |
|
134 } |
|
135 Err(ref error) if error.kind() == ErrorKind::WouldBlock => { |
|
136 let messages = if bytes_read == 0 { |
|
137 Vec::new() |
|
138 } else { |
|
139 decoder.extract_messages() |
|
140 }; |
|
141 break Ok((messages, NetworkClientState::Idle)); |
|
142 } |
|
143 Err(error) => |
|
144 break Err(error) |
|
145 } |
|
146 }; |
|
147 decoder.sweep(); |
|
148 result |
|
149 } |
|
150 |
|
151 pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> { |
|
152 #[cfg(not(feature = "tls-connections"))] |
|
153 match self.socket { |
|
154 ClientSocket::Plain(ref mut stream) => |
|
155 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr), |
|
156 } |
|
157 |
|
158 #[cfg(feature = "tls-connections")] |
|
159 match self.socket { |
|
160 ClientSocket::SslHandshake(ref mut handshake_opt) => { |
|
161 let handshake = std::mem::replace(handshake_opt, None).unwrap(); |
|
162 Ok((Vec::new(), self.handshake_impl(handshake)?)) |
|
163 }, |
|
164 ClientSocket::SslStream(ref mut stream) => |
|
165 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) |
|
166 } |
|
167 } |
|
168 |
|
169 fn write_impl<W: Write>(buf_out: &mut netbuf::Buf, destination: &mut W) -> NetworkResult<()> { |
|
170 let result = loop { |
|
171 match buf_out.write_to(destination) { |
|
172 Ok(bytes) if buf_out.is_empty() || bytes == 0 => |
|
173 break Ok(((), NetworkClientState::Idle)), |
|
174 Ok(_) => (), |
|
175 Err(ref error) if error.kind() == ErrorKind::Interrupted |
|
176 || error.kind() == ErrorKind::WouldBlock => { |
|
177 break Ok(((), NetworkClientState::NeedsWrite)); |
|
178 }, |
|
179 Err(error) => |
|
180 break Err(error) |
|
181 } |
|
182 }; |
|
183 result |
|
184 } |
|
185 |
|
186 pub fn write(&mut self) -> NetworkResult<()> { |
|
187 let result = { |
|
188 #[cfg(not(feature = "tls-connections"))] |
|
189 match self.socket { |
|
190 ClientSocket::Plain(ref mut stream) => |
|
191 NetworkClient::write_impl(&mut self.buf_out, stream) |
|
192 } |
|
193 |
|
194 #[cfg(feature = "tls-connections")] { |
|
195 match self.socket { |
|
196 ClientSocket::SslHandshake(ref mut handshake_opt) => { |
|
197 let handshake = std::mem::replace(handshake_opt, None).unwrap(); |
|
198 Ok(((), self.handshake_impl(handshake)?)) |
|
199 } |
|
200 ClientSocket::SslStream(ref mut stream) => |
|
201 NetworkClient::write_impl(&mut self.buf_out, stream) |
|
202 } |
|
203 } |
|
204 }; |
|
205 |
|
206 self.socket.inner().flush()?; |
|
207 result |
|
208 } |
|
209 |
|
210 pub fn send_raw_msg(&mut self, msg: &[u8]) { |
|
211 self.buf_out.write_all(msg).unwrap(); |
|
212 } |
|
213 |
|
214 pub fn send_string(&mut self, msg: &str) { |
|
215 self.send_raw_msg(&msg.as_bytes()); |
|
216 } |
|
217 |
|
218 pub fn send_msg(&mut self, msg: &HWServerMessage) { |
|
219 self.send_string(&msg.to_raw_protocol()); |
|
220 } |
|
221 } |
|
222 |
|
/// Server-wide TLS configuration from which a per-connection `Ssl` is
/// created for every accepted socket.
#[cfg(feature = "tls-connections")]
struct ServerSsl {
    /// Context built once at startup by `NetworkLayer::create_ssl_context`.
    context: SslContext,
}
|
227 |
|
228 pub struct NetworkLayer { |
|
229 listener: TcpListener, |
|
230 server: HWServer, |
|
231 clients: Slab<NetworkClient>, |
|
232 pending: HashSet<(ClientId, NetworkClientState)>, |
|
233 pending_cache: Vec<(ClientId, NetworkClientState)>, |
|
234 #[cfg(feature = "tls-connections")] |
|
235 ssl: ServerSsl |
|
236 } |
|
237 |
|
238 impl NetworkLayer { |
|
239 pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer { |
|
240 let server = HWServer::new(clients_limit, rooms_limit, Box::new(FileServerIO::new())); |
|
241 let clients = Slab::with_capacity(clients_limit); |
|
242 let pending = HashSet::with_capacity(2 * clients_limit); |
|
243 let pending_cache = Vec::with_capacity(2 * clients_limit); |
|
244 |
|
245 NetworkLayer { |
|
246 listener, server, clients, pending, pending_cache, |
|
247 #[cfg(feature = "tls-connections")] |
|
248 ssl: NetworkLayer::create_ssl_context() |
|
249 } |
|
250 } |
|
251 |
|
252 #[cfg(feature = "tls-connections")] |
|
253 fn create_ssl_context() -> ServerSsl { |
|
254 let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap(); |
|
255 builder.set_verify(SslVerifyMode::NONE); |
|
256 builder.set_read_ahead(true); |
|
257 builder.set_certificate_file("ssl/cert.pem", SslFiletype::PEM).unwrap(); |
|
258 builder.set_private_key_file("ssl/key.pem", SslFiletype::PEM).unwrap(); |
|
259 builder.set_options(SslOptions::NO_COMPRESSION); |
|
260 builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap(); |
|
261 ServerSsl { context: builder.build() } |
|
262 } |
|
263 |
|
264 pub fn register_server(&self, poll: &Poll) -> io::Result<()> { |
|
265 poll.register(&self.listener, utils::SERVER, Ready::readable(), |
|
266 PollOpt::edge()) |
|
267 } |
|
268 |
|
269 fn deregister_client(&mut self, poll: &Poll, id: ClientId) { |
|
270 let mut client_exists = false; |
|
271 if let Some(ref client) = self.clients.get(id) { |
|
272 poll.deregister(client.socket.inner()) |
|
273 .expect("could not deregister socket"); |
|
274 info!("client {} ({}) removed", client.id, client.peer_addr); |
|
275 client_exists = true; |
|
276 } |
|
277 if client_exists { |
|
278 self.clients.remove(id); |
|
279 } |
|
280 } |
|
281 |
|
282 fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: ClientSocket, addr: SocketAddr) { |
|
283 poll.register(client_socket.inner(), Token(id), |
|
284 Ready::readable() | Ready::writable(), |
|
285 PollOpt::edge()) |
|
286 .expect("could not register socket with event loop"); |
|
287 |
|
288 let entry = self.clients.vacant_entry(); |
|
289 let client = NetworkClient::new(id, client_socket, addr); |
|
290 info!("client {} ({}) added", client.id, client.peer_addr); |
|
291 entry.insert(client); |
|
292 } |
|
293 |
|
294 fn flush_server_messages(&mut self) { |
|
295 debug!("{} pending server messages", self.server.output.len()); |
|
296 for (clients, message) in self.server.output.drain(..) { |
|
297 debug!("Message {:?} to {:?}", message, clients); |
|
298 let msg_string = message.to_raw_protocol(); |
|
299 for client_id in clients { |
|
300 if let Some(client) = self.clients.get_mut(client_id) { |
|
301 client.send_string(&msg_string); |
|
302 self.pending.insert((client_id, NetworkClientState::NeedsWrite)); |
|
303 } |
|
304 } |
|
305 } |
|
306 } |
|
307 |
|
308 fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> { |
|
309 #[cfg(not(feature = "tls-connections"))] { |
|
310 Ok(ClientSocket::Plain(socket)) |
|
311 } |
|
312 |
|
313 #[cfg(feature = "tls-connections")] { |
|
314 let ssl = Ssl::new(&self.ssl.context).unwrap(); |
|
315 let mut builder = SslStreamBuilder::new(ssl, socket); |
|
316 builder.set_accept_state(); |
|
317 match builder.handshake() { |
|
318 Ok(stream) => |
|
319 Ok(ClientSocket::SslStream(stream)), |
|
320 Err(HandshakeError::WouldBlock(stream)) => |
|
321 Ok(ClientSocket::SslHandshake(Some(stream))), |
|
322 Err(e) => { |
|
323 debug!("OpenSSL handshake failed: {}", e); |
|
324 Err(Error::new(ErrorKind::Other, "Connection failure")) |
|
325 } |
|
326 } |
|
327 } |
|
328 } |
|
329 |
|
330 pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> { |
|
331 let (client_socket, addr) = self.listener.accept()?; |
|
332 info!("Connected: {}", addr); |
|
333 |
|
334 let client_id = self.server.add_client(); |
|
335 self.register_client(poll, client_id, self.create_client_socket(client_socket)?, addr); |
|
336 self.flush_server_messages(); |
|
337 |
|
338 Ok(()) |
|
339 } |
|
340 |
|
341 fn operation_failed(&mut self, poll: &Poll, client_id: ClientId, error: &Error, msg: &str) -> io::Result<()> { |
|
342 let addr = if let Some(ref mut client) = self.clients.get_mut(client_id) { |
|
343 client.peer_addr |
|
344 } else { |
|
345 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0) |
|
346 }; |
|
347 debug!("{}({}): {}", msg, addr, error); |
|
348 self.client_error(poll, client_id) |
|
349 } |
|
350 |
|
351 pub fn client_readable(&mut self, poll: &Poll, |
|
352 client_id: ClientId) -> io::Result<()> { |
|
353 let messages = |
|
354 if let Some(ref mut client) = self.clients.get_mut(client_id) { |
|
355 client.read() |
|
356 } else { |
|
357 warn!("invalid readable client: {}", client_id); |
|
358 Ok((Vec::new(), NetworkClientState::Idle)) |
|
359 }; |
|
360 |
|
361 match messages { |
|
362 Ok((messages, state)) => { |
|
363 for message in messages { |
|
364 self.server.handle_msg(client_id, message); |
|
365 } |
|
366 match state { |
|
367 NetworkClientState::NeedsRead => { |
|
368 self.pending.insert((client_id, state)); |
|
369 }, |
|
370 NetworkClientState::Closed => |
|
371 self.client_error(&poll, client_id)?, |
|
372 _ => {} |
|
373 }; |
|
374 } |
|
375 Err(e) => self.operation_failed( |
|
376 poll, client_id, &e, |
|
377 "Error while reading from client socket")? |
|
378 } |
|
379 |
|
380 self.flush_server_messages(); |
|
381 |
|
382 if !self.server.removed_clients.is_empty() { |
|
383 let ids: Vec<_> = self.server.removed_clients.drain(..).collect(); |
|
384 for client_id in ids { |
|
385 self.deregister_client(poll, client_id); |
|
386 } |
|
387 } |
|
388 |
|
389 Ok(()) |
|
390 } |
|
391 |
|
392 pub fn client_writable(&mut self, poll: &Poll, |
|
393 client_id: ClientId) -> io::Result<()> { |
|
394 let result = |
|
395 if let Some(ref mut client) = self.clients.get_mut(client_id) { |
|
396 client.write() |
|
397 } else { |
|
398 warn!("invalid writable client: {}", client_id); |
|
399 Ok(((), NetworkClientState::Idle)) |
|
400 }; |
|
401 |
|
402 match result { |
|
403 Ok(((), state)) if state == NetworkClientState::NeedsWrite => { |
|
404 self.pending.insert((client_id, state)); |
|
405 }, |
|
406 Ok(_) => {} |
|
407 Err(e) => self.operation_failed( |
|
408 poll, client_id, &e, |
|
409 "Error while writing to client socket")? |
|
410 } |
|
411 |
|
412 Ok(()) |
|
413 } |
|
414 |
|
415 pub fn client_error(&mut self, poll: &Poll, |
|
416 client_id: ClientId) -> io::Result<()> { |
|
417 self.deregister_client(poll, client_id); |
|
418 self.server.client_lost(client_id); |
|
419 |
|
420 Ok(()) |
|
421 } |
|
422 |
|
423 pub fn has_pending_operations(&self) -> bool { |
|
424 !self.pending.is_empty() |
|
425 } |
|
426 |
|
427 pub fn on_idle(&mut self, poll: &Poll) -> io::Result<()> { |
|
428 if self.has_pending_operations() { |
|
429 let mut cache = replace(&mut self.pending_cache, Vec::new()); |
|
430 cache.extend(self.pending.drain()); |
|
431 for (id, state) in cache.drain(..) { |
|
432 match state { |
|
433 NetworkClientState::NeedsRead => |
|
434 self.client_readable(poll, id)?, |
|
435 NetworkClientState::NeedsWrite => |
|
436 self.client_writable(poll, id)?, |
|
437 _ => {} |
|
438 } |
|
439 } |
|
440 swap(&mut cache, &mut self.pending_cache); |
|
441 } |
|
442 Ok(()) |
|
443 } |
|
444 } |