# HG changeset patch # User alfadur # Date 1679617568 -10800 # Node ID c5c53ebb2d91c1bc3e4cc52b964fb7978a012594 # Parent cd3d16905e0ee6b1f6a3e13c310277f7c66c79d0 add back server TLS support diff -r cd3d16905e0e -r c5c53ebb2d91 rust/hedgewars-server/Cargo.toml --- a/rust/hedgewars-server/Cargo.toml Thu Mar 23 23:41:26 2023 +0300 +++ b/rust/hedgewars-server/Cargo.toml Fri Mar 24 03:26:08 2023 +0300 @@ -5,7 +5,8 @@ authors = [ "Andrey Korotaev " ] [features] -official-server = ["mysql_async", "sha1"] +tls-connections = ["tokio-native-tls"] +official-server = ["mysql_async", "sha1", "tls-connections"] default = [] [dependencies] @@ -25,6 +26,7 @@ sha1 = { version = "0.10.0", optional = true } slab = "0.4" tokio = { version = "1.16", features = ["full"]} +tokio-native-tls = { version = "0.3", optional = true } hedgewars-network-protocol = { path = "../hedgewars-network-protocol" } diff -r cd3d16905e0e -r c5c53ebb2d91 rust/hedgewars-server/src/server/network.rs --- a/rust/hedgewars-server/src/server/network.rs Thu Mar 23 23:41:26 2023 +0300 +++ b/rust/hedgewars-server/src/server/network.rs Fri Mar 24 03:26:08 2023 +0300 @@ -1,16 +1,21 @@ use bytes::{Buf, Bytes}; use log::*; use slab::Slab; +use std::io::Error; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::{ iter::Iterator, net::{IpAddr, SocketAddr}, time::Duration, }; use tokio::{ - io::AsyncReadExt, + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}, net::{TcpListener, TcpStream}, sync::mpsc::{channel, Receiver, Sender}, }; +#[cfg(feature = "tls-connections")] +use tokio_native_tls::{TlsAcceptor, TlsStream}; use crate::{ core::{ @@ -25,7 +30,6 @@ use hedgewars_network_protocol::{ messages::HwServerMessage::Redirect, messages::*, parser::server_message, }; -use tokio::io::AsyncWriteExt; const PING_TIMEOUT: Duration = Duration::from_secs(15); @@ -56,9 +60,65 @@ } } +enum ClientStream { + Tcp(TcpStream), + #[cfg(feature = "tls-connections")] + Tls(TlsStream), +} + +impl Unpin for ClientStream {} + +impl AsyncRead for ClientStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + use ClientStream::*; + match Pin::into_inner(self) { + Tcp(stream) => Pin::new(stream).poll_read(cx, buf), + #[cfg(feature = "tls-connections")] + Tls(stream) => Pin::new(stream).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for ClientStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + use ClientStream::*; + match Pin::into_inner(self) { + Tcp(stream) => Pin::new(stream).poll_write(cx, buf), + #[cfg(feature = "tls-connections")] + Tls(stream) => Pin::new(stream).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use ClientStream::*; + match Pin::into_inner(self) { + Tcp(stream) => Pin::new(stream).poll_flush(cx), + #[cfg(feature = "tls-connections")] + Tls(stream) => Pin::new(stream).poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use ClientStream::*; + match Pin::into_inner(self) { + Tcp(stream) => Pin::new(stream).poll_shutdown(cx), + #[cfg(feature = "tls-connections")] + Tls(stream) => Pin::new(stream).poll_shutdown(cx), + } + } +} + struct NetworkClient { id: ClientId, - socket: TcpStream, + stream: ClientStream, receiver: Receiver, peer_addr: SocketAddr, decoder: ProtocolDecoder, @@ -67,27 +127,27 @@ impl NetworkClient { fn new( id: ClientId, - socket: TcpStream, + stream: ClientStream, peer_addr: SocketAddr, receiver: Receiver, ) -> Self { Self { id, - socket, + stream, peer_addr, receiver, decoder: ProtocolDecoder::new(PING_TIMEOUT), } } - async fn read( - socket: &mut TcpStream, + async fn read( + stream: &mut T, decoder: &mut ProtocolDecoder, ) -> protocol::Result { - let result = decoder.read_from(socket).await; + let result = decoder.read_from(stream).await; if matches!(result, Err(ProtocolError::Timeout)) { - if Self::write(socket, Bytes::from(HwServerMessage::Ping.to_raw_protocol())).await { - decoder.read_from(socket).await + if Self::write(stream, Bytes::from(HwServerMessage::Ping.to_raw_protocol())).await { + decoder.read_from(stream).await } else { Err(ProtocolError::Eof) } @@ -96,8 +156,8 @@ } } - async fn write(socket: &mut TcpStream, mut data: Bytes) -> bool { - !data.has_remaining() || matches!(socket.write_buf(&mut data).await, Ok(n) if n > 0) + async fn write(stream: &mut T, mut data: Bytes) -> bool { + !data.has_remaining() || matches!(stream.write_buf(&mut data).await, Ok(n) if n > 0) } async fn run(mut self, sender: Sender) { @@ -111,7 +171,7 @@ tokio::select! { server_message = self.receiver.recv() => { match server_message { - Some(message) => if !Self::write(&mut self.socket, message).await { + Some(message) => if !Self::write(&mut self.stream, message).await { sender.send(Error("Connection reset by peer".to_string())).await; break; } @@ -120,7 +180,7 @@ } } } - client_message = Self::read(&mut self.socket, &mut self.decoder) => { + client_message = Self::read(&mut self.stream, &mut self.decoder) => { match client_message { Ok(message) => { if !sender.send(Message(message)).await { @@ -130,7 +190,7 @@ Err(e) => { sender.send(Error(format!("{}", e))).await; if matches!(e, ProtocolError::Timeout) { - Self::write(&mut self.socket, Bytes::from(HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol())).await; + Self::write(&mut self.stream, Bytes::from(HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol())).await; } break; } @@ -141,8 +201,16 @@ } } +#[cfg(feature = "tls-connections")] +struct TlsListener { + listener: TcpListener, + acceptor: TlsAcceptor, +} + pub struct NetworkLayer { listener: TcpListener, + #[cfg(feature = "tls-connections")] + tls: TlsListener, server_state: ServerState, clients: Slab>, } @@ -151,36 +219,82 @@ pub async fn run(&mut self) { let (update_tx, mut update_rx) = channel(128); - loop { - tokio::select! { - Ok((stream, addr)) = self.listener.accept() => { - if let Some(client) = self.create_client(stream, addr).await { + async fn accept_plain_branch( + layer: &mut NetworkLayer, + value: (TcpStream, SocketAddr), + update_tx: Sender, + ) { + let (stream, addr) = value; + if let Some(client) = layer.create_client(ClientStream::Tcp(stream), addr).await { + tokio::spawn(client.run(update_tx)); + } + } + + #[cfg(feature = "tls-connections")] + async fn accept_tls_branch( + layer: &mut NetworkLayer, + value: (TcpStream, SocketAddr), + update_tx: Sender, + ) { + let (stream, addr) = value; + match layer.tls.acceptor.accept(stream).await { + Ok(stream) => { + if let Some(client) = layer.create_client(ClientStream::Tls(stream), addr).await + { tokio::spawn(client.run(update_tx.clone())); } } - client_message = update_rx.recv(), if !self.clients.is_empty() => { - use ClientUpdateData::*; - match client_message { - Some(ClientUpdate{ client_id, data: Message(message) } ) => { - self.handle_message(client_id, message).await; - } - Some(ClientUpdate{ client_id, data: Error(e) } ) => { - let mut response = handlers::Response::new(client_id); - info!("Client {} error: {:?}", client_id, e); - response.remove_client(client_id); - handlers::handle_client_loss(&mut self.server_state, client_id, &mut response); - self.handle_response(response).await; - } - None => unreachable!() - } + Err(e) => { + warn!("Unable to establish TLS connection: {}", e); + } + } + } + + async fn client_message_branch( + layer: &mut NetworkLayer, + client_message: Option, + ) { + use ClientUpdateData::*; + match client_message { + Some(ClientUpdate { + client_id, + data: Message(message), + }) => { + layer.handle_message(client_id, message).await; } + Some(ClientUpdate { + client_id, + data: Error(e), + }) => { + let mut response = handlers::Response::new(client_id); + info!("Client {} error: {:?}", client_id, e); + response.remove_client(client_id); + handlers::handle_client_loss(&mut layer.server_state, client_id, &mut response); + layer.handle_response(response).await; + } + None => unreachable!(), + } + } + + loop { + #[cfg(not(feature = "tls-connections"))] + tokio::select! { + Ok(value) = self.listener.accept() => accept_plain_branch(self, value, update_tx.clone()).await, + client_message = update_rx.recv(), if !self.clients.is_empty() => client_message_branch(self, client_message).await + } + + #[cfg(feature = "tls-connections")] + tokio::select! { + Ok(value) = self.listener.accept() => accept_plain_branch(self, value, update_tx.clone()).await, + Ok(value) = self.tls.listener.accept() => accept_tls_branch(self, value, update_tx.clone()).await, + client_message = update_rx.recv(), if !self.clients.is_empty() => client_message_branch(self, client_message).await } } } async fn create_client( &mut self, - stream: TcpStream, + stream: ClientStream, addr: SocketAddr, ) -> Option { let entry = self.clients.vacant_entry(); @@ -263,6 +377,10 @@ pub struct NetworkLayerBuilder { listener: Option, + #[cfg(feature = "tls-connections")] + tls_listener: Option, + #[cfg(feature = "tls-connections")] + tls_acceptor: Option, clients_capacity: usize, rooms_capacity: usize, } @@ -273,6 +391,10 @@ clients_capacity: 1024, rooms_capacity: 512, listener: None, + #[cfg(feature = "tls-connections")] + tls_listener: None, + #[cfg(feature = "tls-connections")] + tls_acceptor: None, } } } @@ -285,6 +407,22 @@ } } + #[cfg(feature = "tls-connections")] + pub fn with_tls_acceptor(self, listener: TlsAcceptor) -> Self { + Self { + tls_acceptor: Option::from(listener), + ..self + } + } + + #[cfg(feature = "tls-connections")] + pub fn with_tls_listener(self, listener: TlsAcceptor) -> Self { + Self { + tls_acceptor: Option::from(listener), + ..self + } + } + pub fn build(self) -> NetworkLayer { let server_state = ServerState::new(self.clients_capacity, self.rooms_capacity); @@ -292,6 +430,11 @@ NetworkLayer { listener: self.listener.expect("No listener provided"), + #[cfg(feature = "tls-connections")] + tls: TlsListener { + listener: self.tls_listener.expect("No TLS listener provided"), + acceptor: self.tls_acceptor.expect("No TLS acceptor provided"), + }, server_state, clients, }