rust/hedgewars-server/src/server/network.rs
author alfadur
Tue, 01 Feb 2022 02:23:35 +0300
changeset 15831 7d0f747afcb8
parent 15804 747278149393
child 15832 a4d505a32879
permissions -rw-r--r--
move server network to tokio

use bytes::{Buf, Bytes};
use log::*;
use slab::Slab;
use std::{
    collections::HashSet,
    io,
    io::{Error, ErrorKind, Read, Write},
    iter::Iterator,
    mem::{replace, swap},
    net::{IpAddr, Ipv4Addr, SocketAddr},
    num::NonZeroU32,
    time::Duration,
    time::Instant,
};
use tokio::{
    io::AsyncReadExt,
    net::{TcpListener, TcpStream},
    sync::mpsc::{channel, Receiver, Sender},
};

use crate::{
    core::{
        events::{TimedEvents, Timeout},
        types::ClientId,
    },
    handlers,
    handlers::{IoResult, IoTask, ServerState},
    protocol::ProtocolDecoder,
    utils,
};
use hedgewars_network_protocol::{
    messages::HwServerMessage::Redirect, messages::*, parser::server_message,
};
use tokio::io::AsyncWriteExt;

/// Payload of an update that a client task reports to the network layer.
enum ClientUpdateData {
    /// A protocol message decoded from the client's socket.
    Message(HwProtocolMessage),
    /// The connection is no longer usable; carries a textual description.
    Error(String),
}

/// An update sent from a client task to the network layer over the shared
/// update channel.
struct ClientUpdate {
    /// Id of the client the update originates from.
    client_id: ClientId,
    /// What happened on that client's connection.
    data: ClientUpdateData,
}

/// Pairs the update channel sender with a fixed client id, so every update
/// sent through it is attributed to that client.
struct ClientUpdateSender {
    /// Id stamped onto every outgoing `ClientUpdate`.
    client_id: ClientId,
    /// Sending half of the channel read by `NetworkLayer::run`.
    sender: Sender<ClientUpdate>,
}

impl ClientUpdateSender {
    /// Wraps `data` into a `ClientUpdate` carrying this sender's client id and
    /// queues it on the update channel.
    ///
    /// Returns `true` if the update was accepted by the channel and `false`
    /// if sending failed.
    async fn send(&mut self, data: ClientUpdateData) -> bool {
        let update = ClientUpdate {
            client_id: self.client_id,
            data,
        };

        match self.sender.send(update).await {
            Ok(()) => true,
            Err(_) => false,
        }
    }
}

/// Per-connection state owned by the task spawned for each accepted client.
struct NetworkClient {
    /// Id assigned to this client by the network layer (its slab key).
    id: ClientId,
    /// The client's TCP connection.
    socket: TcpStream,
    /// Serialized server messages queued for writing to `socket`.
    receiver: Receiver<Bytes>,
    /// Remote address the connection was accepted from.
    peer_addr: SocketAddr,
    /// Decoder that turns bytes read from `socket` into protocol messages.
    decoder: ProtocolDecoder,
}

impl NetworkClient {
    /// Creates the client state for a freshly accepted connection, with a new
    /// protocol decoder.
    fn new(
        id: ClientId,
        socket: TcpStream,
        peer_addr: SocketAddr,
        receiver: Receiver<Bytes>,
    ) -> Self {
        Self {
            id,
            socket,
            peer_addr,
            receiver,
            decoder: ProtocolDecoder::new(),
        }
    }

    /// Reads the next protocol message from the socket through the decoder.
    ///
    /// The decoder's result is passed through unchanged; `run` treats `None`
    /// as a lost connection.
    async fn read(&mut self) -> Option<HwProtocolMessage> {
        self.decoder.read_from(&mut self.socket).await
    }

    /// Writes all of `data` to the socket.
    ///
    /// Returns `true` once the whole buffer has been written (immediately for
    /// an empty buffer) and `false` if a write fails or makes no progress.
    async fn write(&mut self, mut data: Bytes) -> bool {
        // A single `write_buf` call may accept only a prefix of the buffer, so
        // keep writing until it is drained; otherwise the tail of the message
        // would be silently dropped and the protocol stream corrupted.
        while data.has_remaining() {
            match self.socket.write_buf(&mut data).await {
                Ok(written) if written > 0 => {}
                // An I/O error or a zero-length write means the peer is gone.
                _ => return false,
            }
        }
        true
    }

    /// Runs the client's I/O loop until the connection or one of the channels
    /// is closed.
    ///
    /// Server messages received on `self.receiver` are written to the socket;
    /// messages decoded from the socket are forwarded through `sender`. A
    /// failed write or a failed read is reported as an `Error` update before
    /// the loop exits.
    async fn run(mut self, sender: Sender<ClientUpdate>) {
        use ClientUpdateData::*;
        let mut sender = ClientUpdateSender {
            client_id: self.id,
            sender,
        };

        loop {
            // NOTE(review): the `read_from` future is dropped whenever a server
            // message arrives first; this assumes the decoder keeps any
            // partially read data between calls — confirm against
            // `ProtocolDecoder::read_from`.
            tokio::select! {
                server_message = self.receiver.recv() => {
                    match server_message {
                        Some(message) => {
                            if !self.write(message).await {
                                sender.send(Error("Connection reset by peer".to_string())).await;
                                break;
                            }
                        }
                        // The network layer dropped its sender for this client.
                        None => {
                            break;
                        }
                    }
                }
                client_message = self.decoder.read_from(&mut self.socket) => {
                    match client_message {
                        Some(message) => {
                            // A failed send means the update channel is closed,
                            // so nobody is left to report to.
                            if !sender.send(Message(message)).await {
                                break;
                            }
                        }
                        None => {
                            sender.send(Error("Connection reset by peer".to_string())).await;
                            break;
                        }
                    }
                }
            }
        }
    }
}

/// The server's network front end: accepts TCP connections, feeds client
/// messages to the handlers and routes the resulting responses back.
pub struct NetworkLayer {
    /// Listener that new client connections are accepted from.
    listener: TcpListener,
    /// Server state passed to the `handlers` functions.
    server_state: ServerState,
    /// Per-client senders for serialized server messages, keyed by client id.
    clients: Slab<Sender<Bytes>>,
}

impl NetworkLayer {
    pub async fn run(&mut self) {
        let (update_tx, mut update_rx) = channel(128);

        loop {
            tokio::select! {
                Ok((stream, addr)) = self.listener.accept() => {
                    if let Some(client) = self.create_client(stream, addr).await {
                        tokio::spawn(client.run(update_tx.clone()));
                    }
                }
                client_message = update_rx.recv(), if !self.clients.is_empty() => {
                    use ClientUpdateData::*;
                    match client_message {
                        Some(ClientUpdate{ client_id, data: Message(message) } ) => {
                            self.handle_message(client_id, message).await;
                        }
                        Some(ClientUpdate{ client_id, .. } ) => {
                            let mut response = handlers::Response::new(client_id);
                            handlers::handle_client_loss(&mut self.server_state, client_id, &mut response);
                            self.handle_response(response).await;
                        }
                        None => unreachable!()
                    }
                }
            }
        }
    }

    async fn create_client(
        &mut self,
        stream: TcpStream,
        addr: SocketAddr,
    ) -> Option<NetworkClient> {
        let entry = self.clients.vacant_entry();
        let client_id = entry.key();
        let (tx, rx) = channel(16);
        entry.insert(tx);

        let client = NetworkClient::new(client_id, stream, addr, rx);

        info!("client {} ({}) added", client.id, client.peer_addr);

        let mut response = handlers::Response::new(client_id);

        let added = if let IpAddr::V4(addr) = client.peer_addr.ip() {
            handlers::handle_client_accept(
                &mut self.server_state,
                client_id,
                &mut response,
                addr.octets(),
                addr.is_loopback(),
            )
        } else {
            todo!("implement something")
        };

        self.handle_response(response).await;

        if added {
            Some(client)
        } else {
            None
        }
    }

    async fn handle_message(&mut self, client_id: ClientId, message: HwProtocolMessage) {
        debug!("Handling message {:?} for client {}", message, client_id);
        let mut response = handlers::Response::new(client_id);
        handlers::handle(&mut self.server_state, client_id, &mut response, message);
        self.handle_response(response).await;
    }

    async fn handle_response(&mut self, mut response: handlers::Response) {
        if response.is_empty() {
            return;
        }

        debug!("{} pending server messages", response.len());
        let output = response.extract_messages(&mut self.server_state.server);
        for (clients, message) in output {
            debug!("Message {:?} to {:?}", message, clients);
            Self::send_message(&mut self.clients, message, clients.iter().cloned()).await;
        }

        for client_id in response.extract_removed_clients() {
            if self.clients.contains(client_id) {
                self.clients.remove(client_id);
            }
            info!("Client {} removed", client_id);
        }
    }

    async fn send_message<I>(
        clients: &mut Slab<Sender<Bytes>>,
        message: HwServerMessage,
        to_clients: I,
    ) where
        I: Iterator<Item = ClientId>,
    {
        let msg_string = message.to_raw_protocol();
        let bytes = Bytes::copy_from_slice(msg_string.as_bytes());
        for client_id in to_clients {
            if let Some(client) = clients.get_mut(client_id) {
                if !client.send(bytes.clone()).await.is_ok() {
                    clients.remove(client_id);
                }
            }
        }
    }
}

/// Builder for [`NetworkLayer`].
pub struct NetworkLayerBuilder {
    /// Listener to accept clients from; must be set before calling `build`.
    listener: Option<TcpListener>,
    /// Capacity passed to `ServerState::new` and used to preallocate the
    /// client slab.
    clients_capacity: usize,
    /// Room capacity passed to `ServerState::new`.
    rooms_capacity: usize,
}

impl Default for NetworkLayerBuilder {
    /// Returns a builder with no listener and capacities of 1024 clients and
    /// 512 rooms.
    fn default() -> Self {
        let (clients_capacity, rooms_capacity) = (1024, 512);

        Self {
            listener: None,
            clients_capacity,
            rooms_capacity,
        }
    }
}

impl NetworkLayerBuilder {
    /// Sets the listener the built network layer accepts clients from,
    /// replacing any listener set earlier.
    pub fn with_listener(mut self, listener: TcpListener) -> Self {
        self.listener = Some(listener);
        self
    }

    /// Builds the network layer with a fresh server state and an empty,
    /// preallocated client slab.
    ///
    /// # Panics
    ///
    /// Panics if no listener was provided through `with_listener`.
    pub fn build(self) -> NetworkLayer {
        let Self {
            listener,
            clients_capacity,
            rooms_capacity,
        } = self;

        let server_state = ServerState::new(clients_capacity, rooms_capacity);
        let clients = Slab::with_capacity(clients_capacity);
        let listener = listener.expect("No listener provided");

        NetworkLayer {
            listener,
            server_state,
            clients,
        }
    }
}