# HG changeset patch # User alfadur # Date 1536264752 -10800 # Node ID c8fd12db6215e30df1454fc3fa65b86a637f1769 # Parent 4664da990556286ba4c49c8cdfb81a25bde84205 Add TLS support diff -r 4664da990556 -r c8fd12db6215 gameServer2/Cargo.toml --- a/gameServer2/Cargo.toml Wed Sep 05 19:22:29 2018 +0300 +++ b/gameServer2/Cargo.toml Thu Sep 06 23:12:32 2018 +0300 @@ -8,6 +8,8 @@ [features] official-server = [] +tls-connections = ["openssl"] +default = [] [dependencies] rand = "0.5" @@ -22,6 +24,7 @@ serde = "1.0" serde_yaml = "0.8" serde_derive = "1.0" +openssl = { version = "0.10", optional = true } [dev-dependencies] proptest = "0.8" \ No newline at end of file diff -r 4664da990556 -r c8fd12db6215 gameServer2/src/main.rs --- a/gameServer2/src/main.rs Wed Sep 05 19:22:29 2018 +0300 +++ b/gameServer2/src/main.rs Thu Sep 06 23:12:32 2018 +0300 @@ -17,6 +17,8 @@ #[macro_use] extern crate bitflags; extern crate serde; extern crate serde_yaml; +#[cfg(feature = "tls-connections")] +extern crate openssl; #[macro_use] extern crate serde_derive; //use std::io::*; diff -r 4664da990556 -r c8fd12db6215 gameServer2/src/server/network.rs --- a/gameServer2/src/server/network.rs Wed Sep 05 19:22:29 2018 +0300 +++ b/gameServer2/src/server/network.rs Thu Sep 06 23:12:32 2018 +0300 @@ -1,7 +1,7 @@ extern crate slab; use std::{ - io, io::{Error, ErrorKind, Write}, + io, io::{Error, ErrorKind, Read, Write}, net::{SocketAddr, IpAddr, Ipv4Addr}, collections::HashSet, mem::{swap, replace} @@ -22,6 +22,15 @@ server::{HWServer}, coretypes::ClientId }; +#[cfg(feature = "tls-connections")] +use openssl::{ + ssl::{ + SslMethod, SslContext, Ssl, SslContextBuilder, + SslVerifyMode, SslFiletype, SslOptions, + SslStreamBuilder, HandshakeError, MidHandshakeSslStream, SslStream + }, + error::ErrorStack +}; const MAX_BYTES_PER_READ: usize = 2048; @@ -35,16 +44,43 @@ type NetworkResult = io::Result<(T, NetworkClientState)>; +#[cfg(not(feature = "tls-connections"))] +pub enum ClientSocket { + Plain(TcpStream) +} + +#[cfg(feature = "tls-connections")] +pub enum ClientSocket { + SslHandshake(Option>), + SslStream(SslStream) +} + +impl ClientSocket { + fn inner(&self) -> &TcpStream { + #[cfg(not(feature = "tls-connections"))] + match self { + ClientSocket::Plain(stream) => stream, + } + + #[cfg(feature = "tls-connections")] + match self { + ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(), + ClientSocket::SslHandshake(None) => unreachable!(), + ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref() + } + } +} + pub struct NetworkClient { id: ClientId, - socket: TcpStream, + socket: ClientSocket, peer_addr: SocketAddr, decoder: ProtocolDecoder, buf_out: netbuf::Buf } impl NetworkClient { - pub fn new(id: ClientId, socket: TcpStream, peer_addr: SocketAddr) -> NetworkClient { + pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient { NetworkClient { id, socket, peer_addr, decoder: ProtocolDecoder::new(), @@ -52,31 +88,32 @@ } } - pub fn read_messages(&mut self) -> NetworkResult> { + fn read_impl(decoder: &mut ProtocolDecoder, source: &mut R, + id: ClientId, addr: &SocketAddr) -> NetworkResult> { let mut bytes_read = 0; let result = loop { - match self.decoder.read_from(&mut self.socket) { + match decoder.read_from(source) { Ok(bytes) => { - debug!("Client {}: read {} bytes", self.id, bytes); + debug!("Client {}: read {} bytes", id, bytes); bytes_read += bytes; if bytes == 0 { let result = if bytes_read == 0 { - info!("EOF for client {} ({})", self.id, self.peer_addr); + info!("EOF for client {} ({})", id, addr); (Vec::new(), NetworkClientState::Closed) } else { - (self.decoder.extract_messages(), NetworkClientState::NeedsRead) + (decoder.extract_messages(), NetworkClientState::NeedsRead) }; break Ok(result); } else if bytes_read >= MAX_BYTES_PER_READ { - break Ok((self.decoder.extract_messages(), NetworkClientState::NeedsRead)) + break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead)) } } Err(ref error) if error.kind() == ErrorKind::WouldBlock => { let messages = if bytes_read == 0 { Vec::new() } else { - self.decoder.extract_messages() + decoder.extract_messages() }; break Ok((messages, NetworkClientState::Idle)); } @@ -84,14 +121,48 @@ break Err(error) } }; - self.decoder.sweep(); + decoder.sweep(); result } - pub fn flush(&mut self) -> NetworkResult<()> { + pub fn read(&mut self) -> NetworkResult> { + #[cfg(not(feature = "tls-connections"))] + match self.socket { + ClientSocket::Plain(ref mut stream) => + NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr), + } + + #[cfg(feature = "tls-connections")] + match self.socket { + ClientSocket::SslHandshake(ref mut handshake_opt) => { + let mut handshake = std::mem::replace(handshake_opt, None).unwrap(); + + match handshake.handshake() { + Ok(stream) => { + debug!("TLS handshake with {} ({}) completed", self.id, self.peer_addr); + self.socket = ClientSocket::SslStream(stream); + + Ok((Vec::new(), NetworkClientState::Idle)) + } + Err(HandshakeError::WouldBlock(new_handshake)) => { + *handshake_opt = Some(new_handshake); + Ok((Vec::new(), NetworkClientState::Idle)) + } + Err(e) => { + debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr); + Err(Error::new(ErrorKind::Other, "Connection failure")) + } + } + }, + ClientSocket::SslStream(ref mut stream) => + NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) + } + } + + fn write_impl(buf_out: &mut netbuf::Buf, destination: &mut W) -> NetworkResult<()> { let result = loop { - match self.buf_out.write_to(&mut self.socket) { - Ok(bytes) if self.buf_out.is_empty() || bytes == 0 => + match buf_out.write_to(destination) { + Ok(bytes) if buf_out.is_empty() || bytes == 0 => break Ok(((), NetworkClientState::Idle)), Ok(_) => (), Err(ref error) if error.kind() == ErrorKind::Interrupted @@ -102,7 +173,28 @@ break Err(error) } }; - self.socket.flush()?; + result + } + + pub fn write(&mut self) -> NetworkResult<()> { + let result = { + #[cfg(not(feature = "tls-connections"))] + match self.socket { + ClientSocket::Plain(ref mut stream) => + NetworkClient::write_impl(&mut self.buf_out, stream) + } + + #[cfg(feature = "tls-connections")] { + match self.socket { + ClientSocket::SslHandshake(_) => + Ok(((), NetworkClientState::Idle)), + ClientSocket::SslStream(ref mut stream) => + NetworkClient::write_impl(&mut self.buf_out, stream) + } + } + }; + + self.socket.inner().flush()?; result } @@ -119,12 +211,19 @@ } } +#[cfg(feature = "tls-connections")] +struct ServerSsl { + context: SslContext +} + pub struct NetworkLayer { listener: TcpListener, server: HWServer, clients: Slab, pending: HashSet<(ClientId, NetworkClientState)>, - pending_cache: Vec<(ClientId, NetworkClientState)> + pending_cache: Vec<(ClientId, NetworkClientState)>, + #[cfg(feature = "tls-connections")] + ssl: ServerSsl } impl NetworkLayer { @@ -133,7 +232,24 @@ let clients = Slab::with_capacity(clients_limit); let pending = HashSet::with_capacity(2 * clients_limit); let pending_cache = Vec::with_capacity(2 * clients_limit); - NetworkLayer {listener, server, clients, pending, pending_cache} + + NetworkLayer { + listener, server, clients, pending, pending_cache, + #[cfg(feature = "tls-connections")] + ssl: NetworkLayer::create_ssl_context() + } + } + + #[cfg(feature = "tls-connections")] + fn create_ssl_context() -> ServerSsl { + let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + builder.set_read_ahead(true); + builder.set_certificate_file("ssl/cert.pem", SslFiletype::PEM).unwrap(); + builder.set_private_key_file("ssl/key.pem", SslFiletype::PEM).unwrap(); + builder.set_options(SslOptions::NO_COMPRESSION); + builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap(); + ServerSsl { context: builder.build() } } pub fn register_server(&self, poll: &Poll) -> io::Result<()> { @@ -144,7 +260,7 @@ fn deregister_client(&mut self, poll: &Poll, id: ClientId) { let mut client_exists = false; if let Some(ref client) = self.clients.get(id) { - poll.deregister(&client.socket) + poll.deregister(client.socket.inner()) .expect("could not deregister socket"); info!("client {} ({}) removed", client.id, client.peer_addr); client_exists = true; @@ -154,8 +270,8 @@ } } - fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: TcpStream, addr: SocketAddr) { - poll.register(&client_socket, Token(id), + fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: ClientSocket, addr: SocketAddr) { + poll.register(client_socket.inner(), Token(id), Ready::readable() | Ready::writable(), PollOpt::edge()) .expect("could not register socket with event loop"); @@ -180,12 +296,34 @@ } } + fn create_client_socket(&self, socket: TcpStream) -> io::Result { + #[cfg(not(feature = "tls-connections"))] { + Ok(ClientSocket::Plain(socket)) + } + + #[cfg(feature = "tls-connections")] { + let ssl = Ssl::new(&self.ssl.context).unwrap(); + let mut builder = SslStreamBuilder::new(ssl, socket); + builder.set_accept_state(); + match builder.handshake() { + Ok(stream) => + Ok(ClientSocket::SslStream(stream)), + Err(HandshakeError::WouldBlock(stream)) => + Ok(ClientSocket::SslHandshake(Some(stream))), + Err(e) => { + debug!("OpenSSL handshake failed: {}", e); + Err(Error::new(ErrorKind::Other, "Connection failure")) + } + } + } + } + pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> { let (client_socket, addr) = self.listener.accept()?; info!("Connected: {}", addr); let client_id = self.server.add_client(); - self.register_client(poll, client_id, client_socket, addr); + self.register_client(poll, client_id, self.create_client_socket(client_socket)?, addr); self.flush_server_messages(); Ok(()) @@ -205,7 +343,7 @@ client_id: ClientId) -> io::Result<()> { let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) { - client.read_messages() + client.read() } else { warn!("invalid readable client: {}", client_id); Ok((Vec::new(), NetworkClientState::Idle)) @@ -246,7 +384,7 @@ client_id: ClientId) -> io::Result<()> { let result = if let Some(ref mut client) = self.clients.get_mut(client_id) { - client.flush() + client.write() } else { warn!("invalid writable client: {}", client_id); Ok(((), NetworkClientState::Idle))