1 extern crate slab; |
1 extern crate slab; |
2 |
2 |
3 use std::{ |
3 use std::{ |
4 io, io::{Error, ErrorKind, Read, Write}, |
|
5 net::{SocketAddr, IpAddr, Ipv4Addr}, |
|
6 collections::HashSet, |
4 collections::HashSet, |
7 mem::{swap, replace} |
5 io, |
|
6 io::{Error, ErrorKind, Read, Write}, |
|
7 mem::{replace, swap}, |
|
8 net::{IpAddr, Ipv4Addr, SocketAddr}, |
8 }; |
9 }; |
9 |
10 |
|
11 use log::*; |
10 use mio::{ |
12 use mio::{ |
11 net::{TcpStream, TcpListener}, |
13 net::{TcpListener, TcpStream}, |
12 Poll, PollOpt, Ready, Token |
14 Poll, PollOpt, Ready, Token, |
13 }; |
15 }; |
14 use netbuf; |
16 use netbuf; |
15 use slab::Slab; |
17 use slab::Slab; |
16 use log::*; |
18 |
17 |
19 use super::{core::HWServer, coretypes::ClientId, io::FileServerIO}; |
18 use crate::{ |
20 use crate::{ |
|
21 protocol::{messages::*, ProtocolDecoder}, |
19 utils, |
22 utils, |
20 protocol::{ProtocolDecoder, messages::*} |
|
21 }; |
|
22 use super::{ |
|
23 io::FileServerIO, |
|
24 core::{HWServer}, |
|
25 coretypes::ClientId |
|
26 }; |
23 }; |
27 #[cfg(feature = "tls-connections")] |
24 #[cfg(feature = "tls-connections")] |
28 use openssl::{ |
25 use openssl::{ |
|
26 error::ErrorStack, |
29 ssl::{ |
27 ssl::{ |
30 SslMethod, SslContext, Ssl, SslContextBuilder, |
28 HandshakeError, MidHandshakeSslStream, Ssl, SslContext, SslContextBuilder, SslFiletype, |
31 SslVerifyMode, SslFiletype, SslOptions, |
29 SslMethod, SslOptions, SslStream, SslStreamBuilder, SslVerifyMode, |
32 SslStreamBuilder, HandshakeError, MidHandshakeSslStream, SslStream |
|
33 }, |
30 }, |
34 error::ErrorStack |
|
35 }; |
31 }; |
36 |
32 |
37 const MAX_BYTES_PER_READ: usize = 2048; |
33 const MAX_BYTES_PER_READ: usize = 2048; |
38 |
34 |
39 #[derive(Hash, Eq, PartialEq, Copy, Clone)] |
35 #[derive(Hash, Eq, PartialEq, Copy, Clone)] |
46 |
42 |
47 type NetworkResult<T> = io::Result<(T, NetworkClientState)>; |
43 type NetworkResult<T> = io::Result<(T, NetworkClientState)>; |
48 |
44 |
49 #[cfg(not(feature = "tls-connections"))] |
45 #[cfg(not(feature = "tls-connections"))] |
50 pub enum ClientSocket { |
46 pub enum ClientSocket { |
51 Plain(TcpStream) |
47 Plain(TcpStream), |
52 } |
48 } |
53 |
49 |
54 #[cfg(feature = "tls-connections")] |
50 #[cfg(feature = "tls-connections")] |
55 pub enum ClientSocket { |
51 pub enum ClientSocket { |
56 SslHandshake(Option<MidHandshakeSslStream<TcpStream>>), |
52 SslHandshake(Option<MidHandshakeSslStream<TcpStream>>), |
57 SslStream(SslStream<TcpStream>) |
53 SslStream(SslStream<TcpStream>), |
58 } |
54 } |
59 |
55 |
60 impl ClientSocket { |
56 impl ClientSocket { |
61 fn inner(&self) -> &TcpStream { |
57 fn inner(&self) -> &TcpStream { |
62 #[cfg(not(feature = "tls-connections"))] |
58 #[cfg(not(feature = "tls-connections"))] |
66 |
62 |
67 #[cfg(feature = "tls-connections")] |
63 #[cfg(feature = "tls-connections")] |
68 match self { |
64 match self { |
69 ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(), |
65 ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(), |
70 ClientSocket::SslHandshake(None) => unreachable!(), |
66 ClientSocket::SslHandshake(None) => unreachable!(), |
71 ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref() |
67 ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref(), |
72 } |
68 } |
73 } |
69 } |
74 } |
70 } |
75 |
71 |
76 pub struct NetworkClient { |
72 pub struct NetworkClient { |
77 id: ClientId, |
73 id: ClientId, |
78 socket: ClientSocket, |
74 socket: ClientSocket, |
79 peer_addr: SocketAddr, |
75 peer_addr: SocketAddr, |
80 decoder: ProtocolDecoder, |
76 decoder: ProtocolDecoder, |
81 buf_out: netbuf::Buf |
77 buf_out: netbuf::Buf, |
82 } |
78 } |
83 |
79 |
84 impl NetworkClient { |
80 impl NetworkClient { |
85 pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient { |
81 pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient { |
86 NetworkClient { |
82 NetworkClient { |
87 id, socket, peer_addr, |
83 id, |
|
84 socket, |
|
85 peer_addr, |
88 decoder: ProtocolDecoder::new(), |
86 decoder: ProtocolDecoder::new(), |
89 buf_out: netbuf::Buf::new() |
87 buf_out: netbuf::Buf::new(), |
90 } |
88 } |
91 } |
89 } |
92 |
90 |
93 #[cfg(feature = "tls-connections")] |
91 #[cfg(feature = "tls-connections")] |
94 fn handshake_impl(&mut self, handshake: MidHandshakeSslStream<TcpStream>) -> io::Result<NetworkClientState> { |
92 fn handshake_impl( |
|
93 &mut self, |
|
94 handshake: MidHandshakeSslStream<TcpStream>, |
|
95 ) -> io::Result<NetworkClientState> { |
95 match handshake.handshake() { |
96 match handshake.handshake() { |
96 Ok(stream) => { |
97 Ok(stream) => { |
97 self.socket = ClientSocket::SslStream(stream); |
98 self.socket = ClientSocket::SslStream(stream); |
98 debug!("TLS handshake with {} ({}) completed", self.id, self.peer_addr); |
99 debug!( |
|
100 "TLS handshake with {} ({}) completed", |
|
101 self.id, self.peer_addr |
|
102 ); |
99 Ok(NetworkClientState::Idle) |
103 Ok(NetworkClientState::Idle) |
100 } |
104 } |
101 Err(HandshakeError::WouldBlock(new_handshake)) => { |
105 Err(HandshakeError::WouldBlock(new_handshake)) => { |
102 self.socket = ClientSocket::SslHandshake(Some(new_handshake)); |
106 self.socket = ClientSocket::SslHandshake(Some(new_handshake)); |
103 Ok(NetworkClientState::Idle) |
107 Ok(NetworkClientState::Idle) |
105 Err(HandshakeError::Failure(new_handshake)) => { |
109 Err(HandshakeError::Failure(new_handshake)) => { |
106 self.socket = ClientSocket::SslHandshake(Some(new_handshake)); |
110 self.socket = ClientSocket::SslHandshake(Some(new_handshake)); |
107 debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr); |
111 debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr); |
108 Err(Error::new(ErrorKind::Other, "Connection failure")) |
112 Err(Error::new(ErrorKind::Other, "Connection failure")) |
109 } |
113 } |
110 Err(HandshakeError::SetupFailure(_)) => unreachable!() |
114 Err(HandshakeError::SetupFailure(_)) => unreachable!(), |
111 } |
115 } |
112 } |
116 } |
113 |
117 |
114 fn read_impl<R: Read>(decoder: &mut ProtocolDecoder, source: &mut R, |
118 fn read_impl<R: Read>( |
115 id: ClientId, addr: &SocketAddr) -> NetworkResult<Vec<HWProtocolMessage>> { |
119 decoder: &mut ProtocolDecoder, |
|
120 source: &mut R, |
|
121 id: ClientId, |
|
122 addr: &SocketAddr, |
|
123 ) -> NetworkResult<Vec<HWProtocolMessage>> { |
116 let mut bytes_read = 0; |
124 let mut bytes_read = 0; |
117 let result = loop { |
125 let result = loop { |
118 match decoder.read_from(source) { |
126 match decoder.read_from(source) { |
119 Ok(bytes) => { |
127 Ok(bytes) => { |
120 debug!("Client {}: read {} bytes", id, bytes); |
128 debug!("Client {}: read {} bytes", id, bytes); |
125 (Vec::new(), NetworkClientState::Closed) |
133 (Vec::new(), NetworkClientState::Closed) |
126 } else { |
134 } else { |
127 (decoder.extract_messages(), NetworkClientState::NeedsRead) |
135 (decoder.extract_messages(), NetworkClientState::NeedsRead) |
128 }; |
136 }; |
129 break Ok(result); |
137 break Ok(result); |
|
138 } else if bytes_read >= MAX_BYTES_PER_READ { |
|
139 break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead)); |
130 } |
140 } |
131 else if bytes_read >= MAX_BYTES_PER_READ { |
|
132 break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead)) |
|
133 } |
|
134 } |
141 } |
135 Err(ref error) if error.kind() == ErrorKind::WouldBlock => { |
142 Err(ref error) if error.kind() == ErrorKind::WouldBlock => { |
136 let messages = if bytes_read == 0 { |
143 let messages = if bytes_read == 0 { |
137 Vec::new() |
144 Vec::new() |
138 } else { |
145 } else { |
139 decoder.extract_messages() |
146 decoder.extract_messages() |
140 }; |
147 }; |
141 break Ok((messages, NetworkClientState::Idle)); |
148 break Ok((messages, NetworkClientState::Idle)); |
142 } |
149 } |
143 Err(error) => |
150 Err(error) => break Err(error), |
144 break Err(error) |
|
145 } |
151 } |
146 }; |
152 }; |
147 decoder.sweep(); |
153 decoder.sweep(); |
148 result |
154 result |
149 } |
155 } |
150 |
156 |
151 pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> { |
157 pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> { |
152 #[cfg(not(feature = "tls-connections"))] |
158 #[cfg(not(feature = "tls-connections"))] |
153 match self.socket { |
159 match self.socket { |
154 ClientSocket::Plain(ref mut stream) => |
160 ClientSocket::Plain(ref mut stream) => { |
155 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr), |
161 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) |
|
162 } |
156 } |
163 } |
157 |
164 |
158 #[cfg(feature = "tls-connections")] |
165 #[cfg(feature = "tls-connections")] |
159 match self.socket { |
166 match self.socket { |
160 ClientSocket::SslHandshake(ref mut handshake_opt) => { |
167 ClientSocket::SslHandshake(ref mut handshake_opt) => { |
161 let handshake = std::mem::replace(handshake_opt, None).unwrap(); |
168 let handshake = std::mem::replace(handshake_opt, None).unwrap(); |
162 Ok((Vec::new(), self.handshake_impl(handshake)?)) |
169 Ok((Vec::new(), self.handshake_impl(handshake)?)) |
163 }, |
170 } |
164 ClientSocket::SslStream(ref mut stream) => |
171 ClientSocket::SslStream(ref mut stream) => { |
165 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) |
172 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) |
|
173 } |
166 } |
174 } |
167 } |
175 } |
168 |
176 |
169 fn write_impl<W: Write>(buf_out: &mut netbuf::Buf, destination: &mut W) -> NetworkResult<()> { |
177 fn write_impl<W: Write>(buf_out: &mut netbuf::Buf, destination: &mut W) -> NetworkResult<()> { |
170 let result = loop { |
178 let result = loop { |
171 match buf_out.write_to(destination) { |
179 match buf_out.write_to(destination) { |
172 Ok(bytes) if buf_out.is_empty() || bytes == 0 => |
180 Ok(bytes) if buf_out.is_empty() || bytes == 0 => { |
173 break Ok(((), NetworkClientState::Idle)), |
181 break Ok(((), NetworkClientState::Idle)) |
|
182 } |
174 Ok(_) => (), |
183 Ok(_) => (), |
175 Err(ref error) if error.kind() == ErrorKind::Interrupted |
184 Err(ref error) |
176 || error.kind() == ErrorKind::WouldBlock => { |
185 if error.kind() == ErrorKind::Interrupted |
|
186 || error.kind() == ErrorKind::WouldBlock => |
|
187 { |
177 break Ok(((), NetworkClientState::NeedsWrite)); |
188 break Ok(((), NetworkClientState::NeedsWrite)); |
178 }, |
189 } |
179 Err(error) => |
190 Err(error) => break Err(error), |
180 break Err(error) |
|
181 } |
191 } |
182 }; |
192 }; |
183 result |
193 result |
184 } |
194 } |
185 |
195 |
186 pub fn write(&mut self) -> NetworkResult<()> { |
196 pub fn write(&mut self) -> NetworkResult<()> { |
187 let result = { |
197 let result = { |
188 #[cfg(not(feature = "tls-connections"))] |
198 #[cfg(not(feature = "tls-connections"))] |
189 match self.socket { |
199 match self.socket { |
190 ClientSocket::Plain(ref mut stream) => |
200 ClientSocket::Plain(ref mut stream) => { |
191 NetworkClient::write_impl(&mut self.buf_out, stream) |
201 NetworkClient::write_impl(&mut self.buf_out, stream) |
192 } |
202 } |
193 |
203 } |
194 #[cfg(feature = "tls-connections")] { |
204 |
|
205 #[cfg(feature = "tls-connections")] |
|
206 { |
195 match self.socket { |
207 match self.socket { |
196 ClientSocket::SslHandshake(ref mut handshake_opt) => { |
208 ClientSocket::SslHandshake(ref mut handshake_opt) => { |
197 let handshake = std::mem::replace(handshake_opt, None).unwrap(); |
209 let handshake = std::mem::replace(handshake_opt, None).unwrap(); |
198 Ok(((), self.handshake_impl(handshake)?)) |
210 Ok(((), self.handshake_impl(handshake)?)) |
199 } |
211 } |
200 ClientSocket::SslStream(ref mut stream) => |
212 ClientSocket::SslStream(ref mut stream) => { |
201 NetworkClient::write_impl(&mut self.buf_out, stream) |
213 NetworkClient::write_impl(&mut self.buf_out, stream) |
|
214 } |
202 } |
215 } |
203 } |
216 } |
204 }; |
217 }; |
205 |
218 |
206 self.socket.inner().flush()?; |
219 self.socket.inner().flush()?; |
220 } |
233 } |
221 } |
234 } |
222 |
235 |
223 #[cfg(feature = "tls-connections")] |
236 #[cfg(feature = "tls-connections")] |
224 struct ServerSsl { |
237 struct ServerSsl { |
225 context: SslContext |
238 context: SslContext, |
226 } |
239 } |
227 |
240 |
228 pub struct NetworkLayer { |
241 pub struct NetworkLayer { |
229 listener: TcpListener, |
242 listener: TcpListener, |
230 server: HWServer, |
243 server: HWServer, |
231 clients: Slab<NetworkClient>, |
244 clients: Slab<NetworkClient>, |
232 pending: HashSet<(ClientId, NetworkClientState)>, |
245 pending: HashSet<(ClientId, NetworkClientState)>, |
233 pending_cache: Vec<(ClientId, NetworkClientState)>, |
246 pending_cache: Vec<(ClientId, NetworkClientState)>, |
234 #[cfg(feature = "tls-connections")] |
247 #[cfg(feature = "tls-connections")] |
235 ssl: ServerSsl |
248 ssl: ServerSsl, |
236 } |
249 } |
237 |
250 |
238 impl NetworkLayer { |
251 impl NetworkLayer { |
239 pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer { |
252 pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer { |
240 let server = HWServer::new(clients_limit, rooms_limit, Box::new(FileServerIO::new())); |
253 let server = HWServer::new(clients_limit, rooms_limit, Box::new(FileServerIO::new())); |
241 let clients = Slab::with_capacity(clients_limit); |
254 let clients = Slab::with_capacity(clients_limit); |
242 let pending = HashSet::with_capacity(2 * clients_limit); |
255 let pending = HashSet::with_capacity(2 * clients_limit); |
243 let pending_cache = Vec::with_capacity(2 * clients_limit); |
256 let pending_cache = Vec::with_capacity(2 * clients_limit); |
244 |
257 |
245 NetworkLayer { |
258 NetworkLayer { |
246 listener, server, clients, pending, pending_cache, |
259 listener, |
|
260 server, |
|
261 clients, |
|
262 pending, |
|
263 pending_cache, |
247 #[cfg(feature = "tls-connections")] |
264 #[cfg(feature = "tls-connections")] |
248 ssl: NetworkLayer::create_ssl_context() |
265 ssl: NetworkLayer::create_ssl_context(), |
249 } |
266 } |
250 } |
267 } |
251 |
268 |
252 #[cfg(feature = "tls-connections")] |
269 #[cfg(feature = "tls-connections")] |
253 fn create_ssl_context() -> ServerSsl { |
270 fn create_ssl_context() -> ServerSsl { |
254 let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap(); |
271 let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap(); |
255 builder.set_verify(SslVerifyMode::NONE); |
272 builder.set_verify(SslVerifyMode::NONE); |
256 builder.set_read_ahead(true); |
273 builder.set_read_ahead(true); |
257 builder.set_certificate_file("ssl/cert.pem", SslFiletype::PEM).unwrap(); |
274 builder |
258 builder.set_private_key_file("ssl/key.pem", SslFiletype::PEM).unwrap(); |
275 .set_certificate_file("ssl/cert.pem", SslFiletype::PEM) |
|
276 .unwrap(); |
|
277 builder |
|
278 .set_private_key_file("ssl/key.pem", SslFiletype::PEM) |
|
279 .unwrap(); |
259 builder.set_options(SslOptions::NO_COMPRESSION); |
280 builder.set_options(SslOptions::NO_COMPRESSION); |
260 builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap(); |
281 builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap(); |
261 ServerSsl { context: builder.build() } |
282 ServerSsl { |
|
283 context: builder.build(), |
|
284 } |
262 } |
285 } |
263 |
286 |
264 pub fn register_server(&self, poll: &Poll) -> io::Result<()> { |
287 pub fn register_server(&self, poll: &Poll) -> io::Result<()> { |
265 poll.register(&self.listener, utils::SERVER, Ready::readable(), |
288 poll.register( |
266 PollOpt::edge()) |
289 &self.listener, |
|
290 utils::SERVER, |
|
291 Ready::readable(), |
|
292 PollOpt::edge(), |
|
293 ) |
267 } |
294 } |
268 |
295 |
269 fn deregister_client(&mut self, poll: &Poll, id: ClientId) { |
296 fn deregister_client(&mut self, poll: &Poll, id: ClientId) { |
270 let mut client_exists = false; |
297 let mut client_exists = false; |
271 if let Some(ref client) = self.clients.get(id) { |
298 if let Some(ref client) = self.clients.get(id) { |
277 if client_exists { |
304 if client_exists { |
278 self.clients.remove(id); |
305 self.clients.remove(id); |
279 } |
306 } |
280 } |
307 } |
281 |
308 |
282 fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: ClientSocket, addr: SocketAddr) { |
309 fn register_client( |
283 poll.register(client_socket.inner(), Token(id), |
310 &mut self, |
284 Ready::readable() | Ready::writable(), |
311 poll: &Poll, |
285 PollOpt::edge()) |
312 id: ClientId, |
286 .expect("could not register socket with event loop"); |
313 client_socket: ClientSocket, |
|
314 addr: SocketAddr, |
|
315 ) { |
|
316 poll.register( |
|
317 client_socket.inner(), |
|
318 Token(id), |
|
319 Ready::readable() | Ready::writable(), |
|
320 PollOpt::edge(), |
|
321 ) |
|
322 .expect("could not register socket with event loop"); |
287 |
323 |
288 let entry = self.clients.vacant_entry(); |
324 let entry = self.clients.vacant_entry(); |
289 let client = NetworkClient::new(id, client_socket, addr); |
325 let client = NetworkClient::new(id, client_socket, addr); |
290 info!("client {} ({}) added", client.id, client.peer_addr); |
326 info!("client {} ({}) added", client.id, client.peer_addr); |
291 entry.insert(client); |
327 entry.insert(client); |
297 debug!("Message {:?} to {:?}", message, clients); |
333 debug!("Message {:?} to {:?}", message, clients); |
298 let msg_string = message.to_raw_protocol(); |
334 let msg_string = message.to_raw_protocol(); |
299 for client_id in clients { |
335 for client_id in clients { |
300 if let Some(client) = self.clients.get_mut(client_id) { |
336 if let Some(client) = self.clients.get_mut(client_id) { |
301 client.send_string(&msg_string); |
337 client.send_string(&msg_string); |
302 self.pending.insert((client_id, NetworkClientState::NeedsWrite)); |
338 self.pending |
|
339 .insert((client_id, NetworkClientState::NeedsWrite)); |
303 } |
340 } |
304 } |
341 } |
305 } |
342 } |
306 } |
343 } |
307 |
344 |
308 fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> { |
345 fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> { |
309 #[cfg(not(feature = "tls-connections"))] { |
346 #[cfg(not(feature = "tls-connections"))] |
|
347 { |
310 Ok(ClientSocket::Plain(socket)) |
348 Ok(ClientSocket::Plain(socket)) |
311 } |
349 } |
312 |
350 |
313 #[cfg(feature = "tls-connections")] { |
351 #[cfg(feature = "tls-connections")] |
|
352 { |
314 let ssl = Ssl::new(&self.ssl.context).unwrap(); |
353 let ssl = Ssl::new(&self.ssl.context).unwrap(); |
315 let mut builder = SslStreamBuilder::new(ssl, socket); |
354 let mut builder = SslStreamBuilder::new(ssl, socket); |
316 builder.set_accept_state(); |
355 builder.set_accept_state(); |
317 match builder.handshake() { |
356 match builder.handshake() { |
318 Ok(stream) => |
357 Ok(stream) => Ok(ClientSocket::SslStream(stream)), |
319 Ok(ClientSocket::SslStream(stream)), |
358 Err(HandshakeError::WouldBlock(stream)) => { |
320 Err(HandshakeError::WouldBlock(stream)) => |
359 Ok(ClientSocket::SslHandshake(Some(stream))) |
321 Ok(ClientSocket::SslHandshake(Some(stream))), |
360 } |
322 Err(e) => { |
361 Err(e) => { |
323 debug!("OpenSSL handshake failed: {}", e); |
362 debug!("OpenSSL handshake failed: {}", e); |
324 Err(Error::new(ErrorKind::Other, "Connection failure")) |
363 Err(Error::new(ErrorKind::Other, "Connection failure")) |
325 } |
364 } |
326 } |
365 } |
330 pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> { |
369 pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> { |
331 let (client_socket, addr) = self.listener.accept()?; |
370 let (client_socket, addr) = self.listener.accept()?; |
332 info!("Connected: {}", addr); |
371 info!("Connected: {}", addr); |
333 |
372 |
334 let client_id = self.server.add_client(); |
373 let client_id = self.server.add_client(); |
335 self.register_client(poll, client_id, self.create_client_socket(client_socket)?, addr); |
374 self.register_client( |
|
375 poll, |
|
376 client_id, |
|
377 self.create_client_socket(client_socket)?, |
|
378 addr, |
|
379 ); |
336 self.flush_server_messages(); |
380 self.flush_server_messages(); |
337 |
381 |
338 Ok(()) |
382 Ok(()) |
339 } |
383 } |
340 |
384 |
341 fn operation_failed(&mut self, poll: &Poll, client_id: ClientId, error: &Error, msg: &str) -> io::Result<()> { |
385 fn operation_failed( |
|
386 &mut self, |
|
387 poll: &Poll, |
|
388 client_id: ClientId, |
|
389 error: &Error, |
|
390 msg: &str, |
|
391 ) -> io::Result<()> { |
342 let addr = if let Some(ref mut client) = self.clients.get_mut(client_id) { |
392 let addr = if let Some(ref mut client) = self.clients.get_mut(client_id) { |
343 client.peer_addr |
393 client.peer_addr |
344 } else { |
394 } else { |
345 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0) |
395 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0) |
346 }; |
396 }; |
347 debug!("{}({}): {}", msg, addr, error); |
397 debug!("{}({}): {}", msg, addr, error); |
348 self.client_error(poll, client_id) |
398 self.client_error(poll, client_id) |
349 } |
399 } |
350 |
400 |
351 pub fn client_readable(&mut self, poll: &Poll, |
401 pub fn client_readable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> { |
352 client_id: ClientId) -> io::Result<()> { |
402 let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) { |
353 let messages = |
403 client.read() |
354 if let Some(ref mut client) = self.clients.get_mut(client_id) { |
404 } else { |
355 client.read() |
405 warn!("invalid readable client: {}", client_id); |
356 } else { |
406 Ok((Vec::new(), NetworkClientState::Idle)) |
357 warn!("invalid readable client: {}", client_id); |
407 }; |
358 Ok((Vec::new(), NetworkClientState::Idle)) |
|
359 }; |
|
360 |
408 |
361 match messages { |
409 match messages { |
362 Ok((messages, state)) => { |
410 Ok((messages, state)) => { |
363 for message in messages { |
411 for message in messages { |
364 self.server.handle_msg(client_id, message); |
412 self.server.handle_msg(client_id, message); |
365 } |
413 } |
366 match state { |
414 match state { |
367 NetworkClientState::NeedsRead => { |
415 NetworkClientState::NeedsRead => { |
368 self.pending.insert((client_id, state)); |
416 self.pending.insert((client_id, state)); |
369 }, |
417 } |
370 NetworkClientState::Closed => |
418 NetworkClientState::Closed => self.client_error(&poll, client_id)?, |
371 self.client_error(&poll, client_id)?, |
|
372 _ => {} |
419 _ => {} |
373 }; |
420 }; |
374 } |
421 } |
375 Err(e) => self.operation_failed( |
422 Err(e) => self.operation_failed( |
376 poll, client_id, &e, |
423 poll, |
377 "Error while reading from client socket")? |
424 client_id, |
|
425 &e, |
|
426 "Error while reading from client socket", |
|
427 )?, |
378 } |
428 } |
379 |
429 |
380 self.flush_server_messages(); |
430 self.flush_server_messages(); |
381 |
431 |
382 if !self.server.removed_clients.is_empty() { |
432 if !self.server.removed_clients.is_empty() { |
387 } |
437 } |
388 |
438 |
389 Ok(()) |
439 Ok(()) |
390 } |
440 } |
391 |
441 |
392 pub fn client_writable(&mut self, poll: &Poll, |
442 pub fn client_writable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> { |
393 client_id: ClientId) -> io::Result<()> { |
443 let result = if let Some(ref mut client) = self.clients.get_mut(client_id) { |
394 let result = |
444 client.write() |
395 if let Some(ref mut client) = self.clients.get_mut(client_id) { |
445 } else { |
396 client.write() |
446 warn!("invalid writable client: {}", client_id); |
397 } else { |
447 Ok(((), NetworkClientState::Idle)) |
398 warn!("invalid writable client: {}", client_id); |
448 }; |
399 Ok(((), NetworkClientState::Idle)) |
|
400 }; |
|
401 |
449 |
402 match result { |
450 match result { |
403 Ok(((), state)) if state == NetworkClientState::NeedsWrite => { |
451 Ok(((), state)) if state == NetworkClientState::NeedsWrite => { |
404 self.pending.insert((client_id, state)); |
452 self.pending.insert((client_id, state)); |
405 }, |
453 } |
406 Ok(_) => {} |
454 Ok(_) => {} |
407 Err(e) => self.operation_failed( |
455 Err(e) => { |
408 poll, client_id, &e, |
456 self.operation_failed(poll, client_id, &e, "Error while writing to client socket")? |
409 "Error while writing to client socket")? |
457 } |
410 } |
458 } |
411 |
459 |
412 Ok(()) |
460 Ok(()) |
413 } |
461 } |
414 |
462 |
415 pub fn client_error(&mut self, poll: &Poll, |
463 pub fn client_error(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> { |
416 client_id: ClientId) -> io::Result<()> { |
|
417 self.deregister_client(poll, client_id); |
464 self.deregister_client(poll, client_id); |
418 self.server.client_lost(client_id); |
465 self.server.client_lost(client_id); |
419 |
466 |
420 Ok(()) |
467 Ok(()) |
421 } |
468 } |
428 if self.has_pending_operations() { |
475 if self.has_pending_operations() { |
429 let mut cache = replace(&mut self.pending_cache, Vec::new()); |
476 let mut cache = replace(&mut self.pending_cache, Vec::new()); |
430 cache.extend(self.pending.drain()); |
477 cache.extend(self.pending.drain()); |
431 for (id, state) in cache.drain(..) { |
478 for (id, state) in cache.drain(..) { |
432 match state { |
479 match state { |
433 NetworkClientState::NeedsRead => |
480 NetworkClientState::NeedsRead => self.client_readable(poll, id)?, |
434 self.client_readable(poll, id)?, |
481 NetworkClientState::NeedsWrite => self.client_writable(poll, id)?, |
435 NetworkClientState::NeedsWrite => |
|
436 self.client_writable(poll, id)?, |
|
437 _ => {} |
482 _ => {} |
438 } |
483 } |
439 } |
484 } |
440 swap(&mut cache, &mut self.pending_cache); |
485 swap(&mut cache, &mut self.pending_cache); |
441 } |
486 } |