rust/hedgewars-server/src/server/network.rs
author Wuzzy <Wuzzy2@mail.ru>
Mon, 07 Oct 2019 14:18:31 +0200
changeset 15447 6031c0cfec89
parent 15181 f6115638aa92
permissions -rw-r--r--
Make sure AI doesn't skip after collecting ammo. Fixes bug #796

extern crate slab;

use std::{
    collections::HashSet,
    io,
    io::{Error, ErrorKind, Read, Write},
    mem::{replace, swap},
    net::{IpAddr, Ipv4Addr, SocketAddr},
};

use log::*;
use mio::{
    net::{TcpListener, TcpStream},
    Evented, Poll, PollOpt, Ready, Token,
};
use mio_extras::timer;
use netbuf;
use slab::Slab;

use crate::{
    core::{server::HwServer, types::ClientId},
    handlers,
    handlers::{IoResult, IoTask},
    protocol::{messages::HwServerMessage::Redirect, messages::*, ProtocolDecoder},
    utils,
};

#[cfg(feature = "official-server")]
use super::io::{IoThread, RequestId};

#[cfg(feature = "tls-connections")]
use openssl::{
    error::ErrorStack,
    ssl::{
        HandshakeError, MidHandshakeSslStream, Ssl, SslContext, SslContextBuilder, SslFiletype,
        SslMethod, SslOptions, SslStream, SslStreamBuilder, SslVerifyMode,
    },
};
use std::time::Duration;

/// Once at least this many bytes were read from one client in a single pass,
/// reading stops and the client is re-queued for another pass.
const MAX_BYTES_PER_READ: usize = 2 * 1024;
/// Silence interval after which a client is sent a ping.
const SEND_PING_TIMEOUT: Duration = Duration::from_secs(30);
/// Silence interval after the last ping before the client is dropped.
const DROP_CLIENT_TIMEOUT: Duration = Duration::from_secs(30);
/// Number of pings sent to a silent client before the drop timer is armed.
const PING_PROBES_COUNT: u8 = 2;

/// Follow-up action reported by a read, write or handshake step on a client
/// socket, telling the network layer what to do with the client next.
#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)]
pub enum NetworkClientState {
    /// Nothing more to do until the socket reports readiness again.
    Idle,
    /// Output is still queued; the write should be retried.
    NeedsWrite,
    /// More input may be available; the read should be retried.
    NeedsRead,
    /// The connection is finished (EOF, or output flushed for a closing client).
    Closed,
    /// The TLS handshake has just completed.
    #[cfg(feature = "tls-connections")]
    Connected,
}

/// Result of a socket operation: the produced value plus the follow-up state.
type NetworkResult<T> = io::Result<(T, NetworkClientState)>;

/// Transport underlying a single client connection.
pub enum ClientSocket {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// TLS connection whose handshake has not finished yet. The `Option` is
    /// only `None` transiently, while the handshake state is being advanced.
    #[cfg(feature = "tls-connections")]
    SslHandshake(Option<MidHandshakeSslStream<TcpStream>>),
    /// Established TLS connection.
    #[cfg(feature = "tls-connections")]
    SslStream(SslStream<TcpStream>),
}

impl ClientSocket {
    /// Borrows the raw TCP stream beneath whichever transport is active
    /// (e.g. to (de)register it with mio).
    ///
    /// Panics if the handshake state has been taken out (`SslHandshake(None)`).
    fn inner(&self) -> &TcpStream {
        match *self {
            ClientSocket::Plain(ref tcp) => tcp,
            #[cfg(feature = "tls-connections")]
            ClientSocket::SslHandshake(ref pending) => match *pending {
                Some(ref mid_handshake) => mid_handshake.get_ref(),
                None => unreachable!(),
            },
            #[cfg(feature = "tls-connections")]
            ClientSocket::SslStream(ref tls) => tls.get_ref(),
        }
    }
}

/// Per-connection state kept by the network layer for one client.
pub struct NetworkClient {
    /// Slab key of this client; also used as its mio `Token` and as the id
    /// passed to the protocol handlers.
    id: ClientId,
    /// Transport: plain TCP or (with `tls-connections`) TLS, possibly mid-handshake.
    socket: ClientSocket,
    /// Remote address, used for logging and the loopback check on accept.
    peer_addr: SocketAddr,
    /// Accumulates incoming bytes and splits them into protocol messages.
    decoder: ProtocolDecoder,
    /// Outgoing bytes not yet written to the socket.
    buf_out: netbuf::Buf,
    /// The currently armed ping/drop timer of this client.
    timeout: timer::Timeout,
    /// Set when the client was removed while output was still queued; the
    /// connection is closed once `buf_out` has been flushed.
    pending_close: bool,
}

impl NetworkClient {
    pub fn new(
        id: ClientId,
        socket: ClientSocket,
        peer_addr: SocketAddr,
        timeout: timer::Timeout,
    ) -> NetworkClient {
        NetworkClient {
            id,
            socket,
            peer_addr,
            decoder: ProtocolDecoder::new(),
            buf_out: netbuf::Buf::new(),
            timeout,
            pending_close: false,
        }
    }

    #[cfg(feature = "tls-connections")]
    fn handshake_impl(
        &mut self,
        handshake: MidHandshakeSslStream<TcpStream>,
    ) -> io::Result<NetworkClientState> {
        match handshake.handshake() {
            Ok(stream) => {
                self.socket = ClientSocket::SslStream(stream);
                debug!(
                    "TLS handshake with {} ({}) completed",
                    self.id, self.peer_addr
                );
                Ok(NetworkClientState::Connected)
            }
            Err(HandshakeError::WouldBlock(new_handshake)) => {
                self.socket = ClientSocket::SslHandshake(Some(new_handshake));
                Ok(NetworkClientState::Idle)
            }
            Err(HandshakeError::Failure(new_handshake)) => {
                self.socket = ClientSocket::SslHandshake(Some(new_handshake));
                debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr);
                Err(Error::new(ErrorKind::Other, "Connection failure"))
            }
            Err(HandshakeError::SetupFailure(_)) => unreachable!(),
        }
    }

    fn read_impl<R: Read>(
        decoder: &mut ProtocolDecoder,
        source: &mut R,
        id: ClientId,
        addr: &SocketAddr,
    ) -> NetworkResult<Vec<HwProtocolMessage>> {
        let mut bytes_read = 0;
        let result = loop {
            match decoder.read_from(source) {
                Ok(bytes) => {
                    debug!("Client {}: read {} bytes", id, bytes);
                    bytes_read += bytes;
                    if bytes == 0 {
                        let result = if bytes_read == 0 {
                            info!("EOF for client {} ({})", id, addr);
                            (Vec::new(), NetworkClientState::Closed)
                        } else {
                            (decoder.extract_messages(), NetworkClientState::NeedsRead)
                        };
                        break Ok(result);
                    } else if bytes_read >= MAX_BYTES_PER_READ {
                        break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead));
                    }
                }
                Err(ref error) if error.kind() == ErrorKind::WouldBlock => {
                    let messages = if bytes_read == 0 {
                        Vec::new()
                    } else {
                        decoder.extract_messages()
                    };
                    break Ok((messages, NetworkClientState::Idle));
                }
                Err(error) => break Err(error),
            }
        };
        result
    }

    pub fn read(&mut self) -> NetworkResult<Vec<HwProtocolMessage>> {
        match self.socket {
            ClientSocket::Plain(ref mut stream) => {
                NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
            }
            #[cfg(feature = "tls-connections")]
            ClientSocket::SslHandshake(ref mut handshake_opt) => {
                let handshake = std::mem::replace(handshake_opt, None).unwrap();
                Ok((Vec::new(), self.handshake_impl(handshake)?))
            }
            #[cfg(feature = "tls-connections")]
            ClientSocket::SslStream(ref mut stream) => {
                NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
            }
        }
    }

    fn write_impl<W: Write>(
        buf_out: &mut netbuf::Buf,
        destination: &mut W,
        close_on_empty: bool,
    ) -> NetworkResult<()> {
        let result = loop {
            match buf_out.write_to(destination) {
                Ok(bytes) if buf_out.is_empty() || bytes == 0 => {
                    let status = if buf_out.is_empty() && close_on_empty {
                        NetworkClientState::Closed
                    } else {
                        NetworkClientState::Idle
                    };
                    break Ok(((), status));
                }
                Ok(_) => (),
                Err(ref error)
                    if error.kind() == ErrorKind::Interrupted
                        || error.kind() == ErrorKind::WouldBlock =>
                {
                    break Ok(((), NetworkClientState::NeedsWrite));
                }
                Err(error) => break Err(error),
            }
        };
        result
    }

    pub fn write(&mut self) -> NetworkResult<()> {
        let result = match self.socket {
            ClientSocket::Plain(ref mut stream) => {
                NetworkClient::write_impl(&mut self.buf_out, stream, self.pending_close)
            }
            #[cfg(feature = "tls-connections")]
            ClientSocket::SslHandshake(ref mut handshake_opt) => {
                let handshake = std::mem::replace(handshake_opt, None).unwrap();
                Ok(((), self.handshake_impl(handshake)?))
            }
            #[cfg(feature = "tls-connections")]
            ClientSocket::SslStream(ref mut stream) => {
                NetworkClient::write_impl(&mut self.buf_out, stream, self.pending_close)
            }
        };

        self.socket.inner().flush()?;
        result
    }

    pub fn send_raw_msg(&mut self, msg: &[u8]) {
        self.buf_out.write_all(msg).unwrap();
    }

    pub fn send_string(&mut self, msg: &str) {
        self.send_raw_msg(&msg.as_bytes());
    }

    pub fn replace_timeout(&mut self, timeout: timer::Timeout) -> timer::Timeout {
        replace(&mut self.timeout, timeout)
    }

    pub fn has_pending_sends(&self) -> bool {
        !self.buf_out.is_empty()
    }
}

/// TLS listener together with the SSL context used for connections accepted on it.
#[cfg(feature = "tls-connections")]
struct ServerSsl {
    /// Listener for TLS connections.
    listener: TcpListener,
    /// Server-side SSL context shared by all TLS clients.
    context: SslContext,
}

/// Bookkeeping for tasks handed to the I/O thread (`IoThread`), mapping each
/// request back to the client that issued it.
#[cfg(feature = "official-server")]
pub struct IoLayer {
    /// Id assigned to the next request; incremented on every `send`.
    next_request_id: RequestId,
    /// Outstanding requests as `(request id, requesting client)` pairs.
    request_queue: Vec<(RequestId, ClientId)>,
    /// Handle to the thread that executes the I/O tasks.
    io_thread: IoThread,
}

#[cfg(feature = "official-server")]
impl IoLayer {
    /// Creates an empty request table together with a new `IoThread`.
    fn new() -> Self {
        Self {
            next_request_id: 0,
            request_queue: vec![],
            io_thread: IoThread::new(),
        }
    }

    /// Sends `task` to the I/O thread under a fresh request id and records
    /// that the request belongs to `client_id`.
    fn send(&mut self, client_id: ClientId, task: IoTask) {
        let request_id = self.next_request_id;
        self.next_request_id += 1;
        self.request_queue.push((request_id, client_id));
        self.io_thread.send(request_id, task);
    }

    /// Returns the next available I/O result with the client that requested
    /// it, or `None` once the I/O thread has no result ready.
    ///
    /// Results of cancelled requests (see `cancel`) are discarded and the
    /// channel keeps being drained. Returning `None` for them would end the
    /// caller's `while let` loop and strand the results queued behind them.
    fn try_recv(&mut self) -> Option<(ClientId, IoResult)> {
        while let Some((request_id, result)) = self.io_thread.try_recv() {
            let position = self
                .request_queue
                .iter()
                .position(|(id, _)| *id == request_id);
            if let Some(index) = position {
                // Order of the queue is irrelevant: lookups are by id.
                let (_, client_id) = self.request_queue.swap_remove(index);
                return Some((client_id, result));
            }
        }
        None
    }

    /// Forgets every outstanding request of `client_id`; their results will be
    /// discarded by `try_recv` when they arrive.
    fn cancel(&mut self, client_id: ClientId) {
        self.request_queue.retain(|(_, owner)| *owner != client_id);
    }
}

/// What to do when a client's timer fires.
enum TimeoutEvent {
    /// Send a ping; `probes_count` further pings follow before the drop timer is armed.
    SendPing { probes_count: u8 },
    /// Send a farewell and drop the client; armed after the last unanswered ping.
    DropClient,
}

/// Timer payload: the event to handle and the client it applies to.
struct TimerData(TimeoutEvent, ClientId);

/// Owns the listeners, all client connections and the protocol server, and
/// routes socket readiness, timer and I/O events between them.
pub struct NetworkLayer {
    /// Plaintext listener.
    listener: TcpListener,
    /// Protocol-level server state driven through `handlers`.
    server: HwServer,
    /// Connected clients keyed by `ClientId`.
    clients: Slab<NetworkClient>,
    /// Clients with unfinished reads/writes, retried in `on_idle`.
    pending: HashSet<(ClientId, NetworkClientState)>,
    /// Reusable buffer for draining `pending` in `on_idle` without reallocating.
    pending_cache: Vec<(ClientId, NetworkClientState)>,
    /// TLS listener and context.
    #[cfg(feature = "tls-connections")]
    ssl: ServerSsl,
    /// Bookkeeping for tasks executed by the I/O thread.
    #[cfg(feature = "official-server")]
    io: IoLayer,
    /// Ping/drop timers of all clients.
    timer: timer::Timer<TimerData>,
}

/// Registers `evented` with `poll` under `token` for edge-triggered read readiness.
fn register_read<E: Evented>(poll: &Poll, evented: &E, token: mio::Token) -> io::Result<()> {
    let interest = Ready::readable();
    let trigger = PollOpt::edge();
    poll.register(evented, token, interest, trigger)
}

/// Arms a timer that pings `client_id` after `SEND_PING_TIMEOUT`; after that
/// ping, `probes_count` more pings follow before a drop timer is armed.
fn create_ping_timeout(
    timer: &mut timer::Timer<TimerData>,
    probes_count: u8,
    client_id: ClientId,
) -> timer::Timeout {
    let event = TimeoutEvent::SendPing { probes_count };
    let payload = TimerData(event, client_id);
    timer.set_timeout(SEND_PING_TIMEOUT, payload)
}

/// Arms a timer that drops `client_id` after `DROP_CLIENT_TIMEOUT`.
fn create_drop_timeout(timer: &mut timer::Timer<TimerData>, client_id: ClientId) -> timer::Timeout {
    let payload = TimerData(TimeoutEvent::DropClient, client_id);
    timer.set_timeout(DROP_CLIENT_TIMEOUT, payload)
}

impl NetworkLayer {
    pub fn register(&self, poll: &Poll) -> io::Result<()> {
        register_read(poll, &self.listener, utils::SERVER_TOKEN)?;
        #[cfg(feature = "tls-connections")]
        register_read(poll, &self.ssl.listener, utils::SECURE_SERVER_TOKEN)?;
        register_read(poll, &self.timer, utils::TIMER_TOKEN)?;

        #[cfg(feature = "official-server")]
        self.io.io_thread.register_rx(poll, utils::IO_TOKEN)?;

        Ok(())
    }

    fn deregister_client(&mut self, poll: &Poll, id: ClientId, is_error: bool) {
        if let Some(ref mut client) = self.clients.get_mut(id) {
            poll.deregister(client.socket.inner())
                .expect("could not deregister socket");
            if client.has_pending_sends() && !is_error {
                info!(
                    "client {} ({}) pending removal",
                    client.id, client.peer_addr
                );
                client.pending_close = true;
                poll.register(
                    client.socket.inner(),
                    Token(id),
                    Ready::writable(),
                    PollOpt::edge(),
                )
                .unwrap_or_else(|_| {
                    self.clients.remove(id);
                });
            } else {
                info!("client {} ({}) removed", client.id, client.peer_addr);
                self.clients.remove(id);
            }
            #[cfg(feature = "official-server")]
            self.io.cancel(id);
        }
    }

    fn register_client(
        &mut self,
        poll: &Poll,
        client_socket: ClientSocket,
        addr: SocketAddr,
    ) -> io::Result<ClientId> {
        let entry = self.clients.vacant_entry();
        let client_id = entry.key();

        poll.register(
            client_socket.inner(),
            Token(client_id),
            Ready::readable() | Ready::writable(),
            PollOpt::edge(),
        )?;

        let client = NetworkClient::new(
            client_id,
            client_socket,
            addr,
            create_ping_timeout(&mut self.timer, PING_PROBES_COUNT - 1, client_id),
        );
        info!("client {} ({}) added", client.id, client.peer_addr);
        entry.insert(client);

        Ok(client_id)
    }

    fn handle_response(&mut self, mut response: handlers::Response, poll: &Poll) {
        if response.is_empty() {
            return;
        }

        debug!("{} pending server messages", response.len());
        let output = response.extract_messages(&mut self.server);
        for (clients, message) in output {
            debug!("Message {:?} to {:?}", message, clients);
            let msg_string = message.to_raw_protocol();
            for client_id in clients {
                if let Some(client) = self.clients.get_mut(client_id) {
                    client.send_string(&msg_string);
                    self.pending
                        .insert((client_id, NetworkClientState::NeedsWrite));
                }
            }
        }

        for client_id in response.extract_removed_clients() {
            self.deregister_client(poll, client_id, false);
        }

        #[cfg(feature = "official-server")]
        {
            let client_id = response.client_id();
            for task in response.extract_io_tasks() {
                self.io.send(client_id, task);
            }
        }
    }

    pub fn handle_timeout(&mut self, poll: &Poll) -> io::Result<()> {
        while let Some(TimerData(event, client_id)) = self.timer.poll() {
            match event {
                TimeoutEvent::SendPing { probes_count } => {
                    if let Some(ref mut client) = self.clients.get_mut(client_id) {
                        client.send_string(&HwServerMessage::Ping.to_raw_protocol());
                        client.write()?;
                        let timeout = if probes_count != 0 {
                            create_ping_timeout(&mut self.timer, probes_count - 1, client_id)
                        } else {
                            create_drop_timeout(&mut self.timer, client_id)
                        };
                        client.replace_timeout(timeout);
                    }
                }
                TimeoutEvent::DropClient => {
                    if let Some(ref mut client) = self.clients.get_mut(client_id) {
                        client.send_string(
                            &HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol(),
                        );
                        client.write();
                    }
                    self.operation_failed(
                        poll,
                        client_id,
                        &ErrorKind::TimedOut.into(),
                        "No ping response",
                    )?;
                }
            }
        }
        Ok(())
    }

    #[cfg(feature = "official-server")]
    pub fn handle_io_result(&mut self, poll: &Poll) -> io::Result<()> {
        while let Some((client_id, result)) = self.io.try_recv() {
            debug!("Handling io result {:?} for client {}", result, client_id);
            let mut response = handlers::Response::new(client_id);
            handlers::handle_io_result(&mut self.server, client_id, &mut response, result);
            self.handle_response(response, poll);
        }
        Ok(())
    }

    fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
        Ok(ClientSocket::Plain(socket))
    }

    #[cfg(feature = "tls-connections")]
    fn create_client_secure_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
        let ssl = Ssl::new(&self.ssl.context).unwrap();
        let mut builder = SslStreamBuilder::new(ssl, socket);
        builder.set_accept_state();
        match builder.handshake() {
            Ok(stream) => Ok(ClientSocket::SslStream(stream)),
            Err(HandshakeError::WouldBlock(stream)) => Ok(ClientSocket::SslHandshake(Some(stream))),
            Err(e) => {
                debug!("OpenSSL handshake failed: {}", e);
                Err(Error::new(ErrorKind::Other, "Connection failure"))
            }
        }
    }

    fn init_client(&mut self, poll: &Poll, client_id: ClientId) {
        let mut response = handlers::Response::new(client_id);

        if let ClientSocket::Plain(_) = self.clients[client_id].socket {
            #[cfg(feature = "tls-connections")]
            response.add(Redirect(self.ssl.listener.local_addr().unwrap().port()).send_self())
        }

        handlers::handle_client_accept(
            &mut self.server,
            client_id,
            &mut response,
            self.clients[client_id].peer_addr.ip().is_loopback(),
        );
        self.handle_response(response, poll);
    }

    pub fn accept_client(&mut self, poll: &Poll, server_token: mio::Token) -> io::Result<()> {
        match server_token {
            utils::SERVER_TOKEN => {
                let (client_socket, addr) = self.listener.accept()?;
                info!("Connected(plaintext): {}", addr);
                let client_id =
                    self.register_client(poll, self.create_client_socket(client_socket)?, addr)?;
                self.init_client(poll, client_id);
            }
            #[cfg(feature = "tls-connections")]
            utils::SECURE_SERVER_TOKEN => {
                let (client_socket, addr) = self.ssl.listener.accept()?;
                info!("Connected(TLS): {}", addr);
                self.register_client(poll, self.create_client_secure_socket(client_socket)?, addr)?;
            }
            _ => unreachable!(),
        }

        Ok(())
    }

    fn operation_failed(
        &mut self,
        poll: &Poll,
        client_id: ClientId,
        error: &Error,
        msg: &str,
    ) -> io::Result<()> {
        let addr = if let Some(ref mut client) = self.clients.get_mut(client_id) {
            client.peer_addr
        } else {
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0)
        };
        debug!("{}({}): {}", msg, addr, error);
        self.client_error(poll, client_id)
    }

    pub fn client_readable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
        let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) {
            let timeout = client.replace_timeout(create_ping_timeout(
                &mut self.timer,
                PING_PROBES_COUNT - 1,
                client_id,
            ));
            self.timer.cancel_timeout(&timeout);
            client.read()
        } else {
            warn!("invalid readable client: {}", client_id);
            Ok((Vec::new(), NetworkClientState::Idle))
        };

        let mut response = handlers::Response::new(client_id);

        match messages {
            Ok((messages, state)) => {
                for message in messages {
                    debug!("Handling message {:?} for client {}", message, client_id);
                    handlers::handle(&mut self.server, client_id, &mut response, message);
                }
                match state {
                    NetworkClientState::NeedsRead => {
                        self.pending.insert((client_id, state));
                    }
                    NetworkClientState::Closed => self.client_error(&poll, client_id)?,
                    #[cfg(feature = "tls-connections")]
                    NetworkClientState::Connected => self.init_client(poll, client_id),
                    _ => {}
                };
            }
            Err(e) => self.operation_failed(
                poll,
                client_id,
                &e,
                "Error while reading from client socket",
            )?,
        }

        self.handle_response(response, poll);

        Ok(())
    }

    pub fn client_writable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
        let result = if let Some(ref mut client) = self.clients.get_mut(client_id) {
            client.write()
        } else {
            warn!("invalid writable client: {}", client_id);
            Ok(((), NetworkClientState::Idle))
        };

        match result {
            Ok(((), state)) if state == NetworkClientState::NeedsWrite => {
                self.pending.insert((client_id, state));
            }
            Ok(((), state)) if state == NetworkClientState::Closed => {
                self.deregister_client(poll, client_id, false);
            }
            Ok(_) => (),
            Err(e) => {
                self.operation_failed(poll, client_id, &e, "Error while writing to client socket")?
            }
        }

        Ok(())
    }

    pub fn client_error(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
        let pending_close = self.clients[client_id].pending_close;
        self.deregister_client(poll, client_id, true);

        if !pending_close {
            let mut response = handlers::Response::new(client_id);
            handlers::handle_client_loss(&mut self.server, client_id, &mut response);
            self.handle_response(response, poll);
        }

        Ok(())
    }

    pub fn has_pending_operations(&self) -> bool {
        !self.pending.is_empty()
    }

    pub fn on_idle(&mut self, poll: &Poll) -> io::Result<()> {
        if self.has_pending_operations() {
            let mut cache = replace(&mut self.pending_cache, Vec::new());
            cache.extend(self.pending.drain());
            for (id, state) in cache.drain(..) {
                match state {
                    NetworkClientState::NeedsRead => self.client_readable(poll, id)?,
                    NetworkClientState::NeedsWrite => self.client_writable(poll, id)?,
                    _ => {}
                }
            }
            swap(&mut cache, &mut self.pending_cache);
        }
        Ok(())
    }
}

/// Builder for `NetworkLayer`.
pub struct NetworkLayerBuilder {
    /// Plaintext listener; required by `build`.
    listener: Option<TcpListener>,
    /// TLS listener; required by `build` with `tls-connections`, unused otherwise.
    secure_listener: Option<TcpListener>,
    /// Initial client capacity (server, client slab, and twice this for the pending queue).
    clients_capacity: usize,
    /// Room capacity passed to `HwServer::new`.
    rooms_capacity: usize,
}

impl Default for NetworkLayerBuilder {
    fn default() -> Self {
        Self {
            clients_capacity: 1024,
            rooms_capacity: 512,
            listener: None,
            secure_listener: None,
        }
    }
}

impl NetworkLayerBuilder {
    /// Supplies the plaintext listener; `build` panics without one.
    pub fn with_listener(mut self, listener: TcpListener) -> Self {
        self.listener = Some(listener);
        self
    }

    /// Supplies the TLS listener; `build` panics without one when compiled
    /// with `tls-connections` and ignores it otherwise.
    pub fn with_secure_listener(mut self, listener: TcpListener) -> Self {
        self.secure_listener = Some(listener);
        self
    }

    /// Builds the server-side TLS context for `listener`.
    ///
    /// Client certificates are not verified and compression is disabled.
    /// Certificate and key are loaded from `ssl/cert.pem` and `ssl/key.pem`
    /// (paths relative to the working directory). Panics if either cannot be
    /// loaded or the context cannot be set up.
    #[cfg(feature = "tls-connections")]
    fn create_ssl_context(listener: TcpListener) -> ServerSsl {
        let mut context_builder = SslContextBuilder::new(SslMethod::tls()).unwrap();
        context_builder.set_verify(SslVerifyMode::NONE);
        context_builder.set_read_ahead(true);
        context_builder
            .set_certificate_file("ssl/cert.pem", SslFiletype::PEM)
            .expect("Cannot find certificate file");
        context_builder
            .set_private_key_file("ssl/key.pem", SslFiletype::PEM)
            .expect("Cannot find private key file");
        context_builder.set_options(SslOptions::NO_COMPRESSION);
        context_builder
            .set_cipher_list("DEFAULT:!LOW:!RC4:!EXP")
            .unwrap();
        let context = context_builder.build();
        ServerSsl { listener, context }
    }

    /// Assembles the network layer. Panics if the plaintext listener (or, with
    /// `tls-connections`, the secure listener) was not supplied.
    pub fn build(self) -> NetworkLayer {
        let client_slots = self.clients_capacity;
        let server = HwServer::new(client_slots, self.rooms_capacity);
        let clients = Slab::with_capacity(client_slots);
        let pending_slots = 2 * client_slots;
        let pending = HashSet::with_capacity(pending_slots);
        let pending_cache = Vec::with_capacity(pending_slots);
        let timer = timer::Builder::default().build();
        let listener = self.listener.expect("No listener provided");

        NetworkLayer {
            listener,
            server,
            clients,
            pending,
            pending_cache,
            #[cfg(feature = "tls-connections")]
            ssl: Self::create_ssl_context(
                self.secure_listener.expect("No secure listener provided"),
            ),
            #[cfg(feature = "official-server")]
            io: IoLayer::new(),
            timer,
        }
    }
}