rust/hedgewars-server/src/server/database.rs
author unC0Rr
Fri, 19 Apr 2024 17:45:05 +0200
changeset 16011 52b51d92e88d
parent 15937 e514ceb5e7d6
permissions -rw-r--r--
Fix build

use mysql_async::{self, from_row_opt, params, prelude::*, Pool};
use sha1::{Digest, Sha1};
use tokio::sync::mpsc::{channel, Receiver, Sender};

use crate::handlers::{AccountInfo, Sha1Digest};

/// Yields a single row (`1`) if an account named `:username` exists, and no row otherwise.
const CHECK_ACCOUNT_EXISTS_QUERY: &str =
    r"SELECT 1 FROM users WHERE users.name = :username LIMIT 1";

/// Looks up an account by `:username`.
///
/// Columns, in order: the stored `users.pass` (replaced by an empty string unless
/// `users.status = 1`), the number of `users_roles` rows with `rid = 3`, and the
/// number with `rid = 13`. `Database::get_account` reads the last two columns as
/// the admin and contributor flags respectively.
const GET_ACCOUNT_QUERY: &str = r"SELECT CASE WHEN users.status = 1 THEN users.pass ELSE '' END,
     (SELECT COUNT(users_roles.rid) FROM users_roles WHERE users.uid = users_roles.uid AND users_roles.rid = 3),
     (SELECT COUNT(users_roles.rid) FROM users_roles WHERE users.uid = users_roles.uid AND users_roles.rid = 13)
     FROM users WHERE users.name = :username";

/// Appends one `gameserver_stats` row holding the `:players` and `:rooms` counts,
/// timestamped by the database server's `UNIX_TIMESTAMP()`.
const STORE_STATS_QUERY: &str = r"INSERT INTO gameserver_stats
      (players, rooms, last_update)
      VALUES
      (:players, :rooms, UNIX_TIMESTAMP())";

/// Fetches the `filename` column of the `achievements` row whose `id` is `:id`.
const GET_REPLAY_NAME_QUERY: &str = r"SELECT filename FROM achievements WHERE id = :id";

/// Server load snapshot written to the `gameserver_stats` table by `Database::store_stats`.
///
/// The fields are public so that code outside this module can build a value to
/// pass to the public `store_stats` method.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ServerStatistics {
    /// Value stored in the `rooms` column.
    pub rooms: u32,
    /// Value stored in the `players` column.
    pub players: u32,
}

/// Placeholder for per-game achievement data; `Database::store_achievements` does not
/// persist anything yet.
#[derive(Debug, Clone, Default)]
pub struct Achievements {}

/// A request sent to the `Database` worker and handled by `Database::run`.
///
/// NOTE(review): several variants carry passwords or password hashes, so avoid
/// adding a derived `Debug` that could leak them into logs.
pub enum DatabaseQuery {
    /// Ask whether an account named `nick` exists; answered with
    /// `DatabaseResponse::AccountRegistered`.
    CheckRegistered {
        nick: String,
    },
    /// Authenticate `nick`; answered with `DatabaseResponse::Account`.
    GetAccount {
        nick: String,
        /// Protocol number mixed into the SHA-1 challenge hash.
        protocol: u16,
        /// Hash supplied by the client; compared against the hash the server computes
        /// from the stored password and both salts.
        password_hash: String,
        client_salt: String,
        server_salt: String,
    },
    /// Check a checker's credentials; `password` is compared directly with the stored
    /// password column. Answered with `DatabaseResponse::CheckerAccount`.
    GetCheckerAccount {
        nick: String,
        password: String,
    },
    /// Look up the replay filename for the `achievements` row `id`.
    /// `Database::run` currently sends no response for this query.
    GetReplayFilename {
        id: u32,
    },
}

/// An answer produced by `Database::run` for a `DatabaseQuery`.
///
/// Database errors are reported as the negative answer (`false` / `None`), not as a
/// separate error variant.
pub enum DatabaseResponse {
    /// Answer to `DatabaseQuery::CheckRegistered`.
    AccountRegistered(bool),
    /// Answer to `DatabaseQuery::GetAccount`; `None` when there is no such account,
    /// the hash does not match, or the lookup failed.
    Account(Option<AccountInfo>),
    /// Answer to `DatabaseQuery::GetCheckerAccount`.
    CheckerAccount { is_registered: bool },
}

/// Asynchronous database worker: receives `DatabaseQuery` values over a channel and
/// answers them with `DatabaseResponse` values, using a `mysql_async` connection pool.
pub struct Database {
    /// Connection pool; each query method checks out its own connection.
    pool: Pool,
    /// Incoming queries consumed by `run`.
    query_rx: Receiver<DatabaseQuery>,
    /// Outgoing answers produced by `run`.
    response_tx: Sender<DatabaseResponse>,
}

impl Database {
    /// Creates a database worker backed by a connection pool for `url`.
    ///
    /// NOTE(review): the query sender and the response receiver created here are
    /// dropped before this returns, so `run` finishes as soon as it is polled and no
    /// response can be read. Prefer [`Database::with_channels`], which hands those
    /// channel ends back to the caller.
    pub fn new(url: &str) -> Self {
        let (database, _query_tx, _response_rx) = Self::with_channels(url);
        database
    }

    /// Creates a database worker together with the channel ends used to talk to it.
    ///
    /// Returns `(database, query_tx, response_rx)`. Send [`DatabaseQuery`] values
    /// through `query_tx` and read [`DatabaseResponse`] values from `response_rx`
    /// while [`Database::run`] is being driven. Both channels buffer up to 32 messages.
    pub fn with_channels(
        url: &str,
    ) -> (Self, Sender<DatabaseQuery>, Receiver<DatabaseResponse>) {
        let (query_tx, query_rx) = channel(32);
        let (response_tx, response_rx) = channel(32);
        let database = Self {
            pool: Pool::new(url),
            query_rx,
            response_tx,
        };
        (database, query_tx, response_rx)
    }

    /// Serves queries until every query sender is dropped or the response receiver
    /// is closed.
    ///
    /// Database errors are folded into the negative answer (`false` / `None`)
    /// instead of being propagated.
    pub async fn run(&mut self) {
        use DatabaseResponse::*;
        while let Some(query) = self.query_rx.recv().await {
            let response = match query {
                DatabaseQuery::CheckRegistered { nick } => {
                    let is_registered = self.get_is_registered(&nick).await.unwrap_or(false);
                    AccountRegistered(is_registered)
                }
                DatabaseQuery::GetAccount {
                    nick,
                    protocol,
                    password_hash,
                    client_salt,
                    server_salt,
                } => {
                    let account = self
                        .get_account(&nick, protocol, &password_hash, &client_salt, &server_salt)
                        .await
                        .unwrap_or(None);
                    Account(account)
                }
                DatabaseQuery::GetCheckerAccount { nick, password } => {
                    let is_registered = self
                        .get_checker_account(&nick, &password)
                        .await
                        .unwrap_or(false);
                    CheckerAccount { is_registered }
                }
                DatabaseQuery::GetReplayFilename { id } => {
                    // TODO: `DatabaseResponse` has no variant for replay names yet, so the
                    // looked-up filename is discarded and nothing is sent back.
                    let _filename = self.get_replay_name(id).await;
                    continue;
                }
            };

            // `send` fails only when the receiving half is gone; nobody can read
            // further answers, so stop serving.
            if self.response_tx.send(response).await.is_err() {
                break;
            }
        }
    }

    /// Returns `Ok(true)` when an account named `nick` exists.
    ///
    /// # Errors
    /// Propagates connection and query errors from `mysql_async`.
    pub async fn get_is_registered(&mut self, nick: &str) -> mysql_async::Result<bool> {
        let mut connection = self.pool.get_conn().await?;
        let row = CHECK_ACCOUNT_EXISTS_QUERY
            .with(params! { "username" => nick })
            .first::<u32, _>(&mut connection)
            .await?;
        // The query yields a row exactly when the account exists.
        Ok(row.is_some())
    }

    /// Authenticates `nick` with the client's challenge hash.
    ///
    /// Computes the expected client hash from the stored password, `protocol` and the
    /// two salts. If it equals `password_hash`, returns the account info, including
    /// the server-side hash (salts swapped). Otherwise returns `Ok(None)`. Also
    /// returns `Ok(None)` for unknown accounts and for accounts the query reports
    /// with an empty password (`users.status != 1`).
    ///
    /// # Errors
    /// Propagates connection and query errors from `mysql_async`.
    pub async fn get_account(
        &mut self,
        nick: &str,
        protocol: u16,
        password_hash: &str,
        client_salt: &str,
        server_salt: &str,
    ) -> mysql_async::Result<Option<AccountInfo>> {
        let mut connection = self.pool.get_conn().await?;
        if let Some((mut password, is_admin, is_contributor)) = GET_ACCOUNT_QUERY
            .with(params! { "username" => nick })
            .first::<(String, i32, i32), _>(&mut connection)
            .await?
        {
            // GET_ACCOUNT_QUERY substitutes '' for inactive accounts. The remaining hash
            // inputs (protocol, salts, fixed suffix) are presumably not secret from the
            // client, so accepting a hash over an empty password would let anyone
            // log in as such an account.
            if password.is_empty() {
                return Ok(None);
            }

            let client_hash = get_hash(protocol, &password, client_salt, server_salt);
            let server_hash = get_hash(protocol, &password, server_salt, client_salt);
            // Best-effort scrub of the stored password before it is dropped;
            // `replace_range` may reallocate, so the old bytes are not guaranteed wiped.
            password.replace_range(.., "🦔🦔🦔🦔🦔🦔🦔🦔");

            if client_hash == password_hash {
                Ok(Some(AccountInfo {
                    is_registered: true,
                    is_admin: is_admin == 1,
                    is_contributor: is_contributor == 1,
                    server_hash,
                }))
            } else {
                Ok(None)
            }
        } else {
            Ok(None)
        }
    }

    /// Returns `Ok(true)` when `checker_password` equals the stored password of `nick`.
    ///
    /// Accounts the query reports with an empty password (inactive) never match, so
    /// an empty `checker_password` cannot authenticate them.
    ///
    /// # Errors
    /// Propagates connection and query errors from `mysql_async`.
    pub async fn get_checker_account(
        &mut self,
        nick: &str,
        checker_password: &str,
    ) -> mysql_async::Result<bool> {
        let mut connection = self.pool.get_conn().await?;
        if let Some((password, _, _)) = GET_ACCOUNT_QUERY
            .with(params! { "username" => nick })
            .first::<(String, i32, i32), _>(&mut connection)
            .await?
        {
            Ok(!password.is_empty() && checker_password == password)
        } else {
            Ok(false)
        }
    }

    /// Inserts one `gameserver_stats` row with the given player and room counts.
    ///
    /// # Errors
    /// Propagates connection and query errors from `mysql_async`.
    pub async fn store_stats(&mut self, stats: &ServerStatistics) -> mysql_async::Result<()> {
        let mut connection = self.pool.get_conn().await?;
        STORE_STATS_QUERY
            .with(params! {
                "players" => stats.players,
                "rooms" => stats.rooms,
            })
            .ignore(&mut connection)
            .await
    }

    /// Placeholder: does not persist anything yet and always succeeds.
    pub async fn store_achievements(
        &mut self,
        _achievements: &Achievements,
    ) -> mysql_async::Result<()> {
        Ok(())
    }

    /// Returns the replay filename stored for achievement `replay_id`, or `Ok(None)`
    /// when no such row exists.
    ///
    /// # Errors
    /// Propagates connection and query errors from `mysql_async`.
    pub async fn get_replay_name(&mut self, replay_id: u32) -> mysql_async::Result<Option<String>> {
        let mut connection = self.pool.get_conn().await?;
        GET_REPLAY_NAME_QUERY
            .with(params! { "id" => replay_id })
            .first::<String, _>(&mut connection)
            .await
    }
}

/// Computes the SHA-1 digest used in the login challenge.
///
/// The hashed text is `salt1`, `salt2`, `web_password`, the decimal
/// `protocol_number` and the literal `!hedgewars`, concatenated in that order.
/// `Database::get_account` calls this twice with the salts swapped to obtain the
/// client-side and server-side hashes.
fn get_hash(protocol_number: u16, web_password: &str, salt1: &str, salt2: &str) -> Sha1Digest {
    let mut hasher = Sha1::new();
    hasher.update(format!("{salt1}{salt2}{web_password}{protocol_number}!hedgewars"));
    let digest = hasher.finalize();
    Sha1Digest::new(digest.try_into().unwrap())
}