tools/hw2irc/net/ercatec/hw/ProtocolConnection.java
author sheepluva
Wed, 26 May 2021 16:32:43 -0400
changeset 15784 823cf18be1fc
permissions -rw-r--r--
Hedgewars lobby to IRC proxy

/*
 * Java net client for Hedgewars, a free turn based strategy game
 * Copyright (c) 2011 Richard Karolyi <sheepluva@ercatec.net>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; version 2 of the License
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA
 */

package net.ercatec.hw;

import java.lang.*;
import java.lang.IllegalArgumentException;
import java.lang.Runnable;
import java.lang.Thread;
import java.io.*;
import java.net.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Scanner;
import java.util.Vector;
// for auth
import java.math.BigInteger;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.security.NoSuchAlgorithmException;

public final class ProtocolConnection implements Runnable
{
    private static final String DEFAULT_HOST = "netserver.hedgewars.org";
    private static final int DEFAULT_PORT = 46631;
    private static final String PROTOCOL_VERSION = "53";

    private final Socket socket;
    private BufferedReader fromSvr;
    private PrintWriter toSvr;
    private final INetClient netClient;
    private boolean quit;
    private boolean debug;

    private final String host;
    private final int port;

    private String nick;

    public ProtocolConnection(INetClient netClient) throws Exception {
        this(netClient, DEFAULT_HOST);
    }

    public ProtocolConnection(INetClient netClient, String host) throws Exception {
        this(netClient, host, DEFAULT_PORT);
    }

    public ProtocolConnection(INetClient netClient, String host, int port) throws Exception {
        this.netClient = netClient;
        this.host = host;
        this.port = port;
        this.nick = nick = "";
        this.quit = false;

        fromSvr = null;
        toSvr = null;

        try {
            socket = new Socket(host, port);
            fromSvr = new BufferedReader(new InputStreamReader(socket.getInputStream()));
            toSvr = new PrintWriter(socket.getOutputStream(), true);
        }
        catch(Exception ex) {
            throw ex;
        }

        ProtocolMessage firstMsg = processNextMessage();
        if (firstMsg.getType() != ProtocolMessage.Type.CONNECTED) {
            closeConnection();
            throw new Exception("First Message wasn't CONNECTED.");
        }

    }

    public void run() {

        try {
            while (!quit) {
                processNextMessage();
            }
        }
        catch(Exception ex) {
            netClient.logError("FATAL: Run loop died unexpectedly!");
            ex.printStackTrace();
            handleConnectionLoss();
        }

        // only gets here when connection was closed
    }

    public void processMessages() {
        processMessages(false);
    }

    public Thread processMessages(boolean inNewThread)
    {
        if (inNewThread)
            return new Thread(this);

        run();
        return null;
    }

    public void processNextClientFlagsMessages()
    {
        while (!quit) {
            if (!processNextMessage(true).isValid)
                break;
        }
    }

    private ProtocolMessage processNextMessage() {
        return processNextMessage(false);
    }

    private void handleConnectionLoss() {
        closeConnection();
        netClient.onConnectionLoss();
    }

    public void close() {
        this.closeConnection();
    }

    private synchronized void closeConnection() {
        if (quit)
            return;

        quit = true;
        try {
            if (fromSvr != null)
                fromSvr.close();
        } catch(Exception ex) {};
        try {
            if (toSvr != null)
                toSvr.close();
        } catch(Exception ex) {};
        try {
            socket.close();
        } catch(Exception ex) {};
    }

    private String resumeLine = "";

    private ProtocolMessage processNextMessage(boolean onlyIfClientFlags)
    {
        String line;
        final List<String> parts = new ArrayList<String>(32);

        while (!quit) {

            if (!resumeLine.isEmpty()) {
                    line = resumeLine;
                    resumeLine = "";
                }
            else {
                try {
                    line = fromSvr.readLine();
                    
                    if (onlyIfClientFlags && (parts.size() == 0)
                        && !line.equals("CLIENT_FLAGS")) {
                            resumeLine = line;
                            // return invalid message
                            return new ProtocolMessage();
                        }
                }
                catch(Exception whoops) {
                    handleConnectionLoss();
                    break;
                }
            }

            if (line == null) {
                handleConnectionLoss();
                // return invalid message
                return new ProtocolMessage();
            }

            if (!quit && line.isEmpty()) {

                if (parts.size() > 0) {

                    ProtocolMessage msg = new ProtocolMessage(parts);

                    netClient.logDebug("Server: " + msg.toString());

                    if (!msg.isValid()) {
                        netClient.onMalformedMessage(msg.toString());
                        if (msg.getType() != ProtocolMessage.Type.BYE)
                            continue;
                    }

                    final String[] args = msg.getArguments();
                    netClient.sanitizeInputs(args);


                    final int argc = args.length;

                    try {
                        switch (msg.getType()) {

                            case PING:
                                netClient.onPing();
                                break;

                            case LOBBY__JOINED:
                                try {
                                    assertAuthNotIncomplete();
                                }
                                catch (Exception ex) {
                                    disconnect();
                                    netClient.onDisconnect(ex.getMessage());
                                }
                                netClient.onLobbyJoin(args);
                                break;

                            case LOBBY__LEFT:
                                netClient.onLobbyLeave(args[0], args[1]);
                                break;

                            case CLIENT_FLAGS:
                                String user;
                                final String flags = args[0];
                                if (flags.length() < 2) {
                                    //netClient.onMalformedMessage(msg.toString());
                                    break;
                                }
                                final char mode = flags.charAt(0);
                                if ((mode != '-') && (mode != '+')) {
                                    //netClient.onMalformedMessage(msg.toString());
                                    break;
                                }

                                final int l = flags.length();

                                for (int i = 1; i < l; i++) {
                                    // set flag type
                                    final INetClient.UserFlagType flag;
                                    // TODO support more flags
                                    switch (flags.charAt(i)) {
                                        case 'a':
                                            flag = INetClient.UserFlagType.ADMIN;
                                            break;
                                        case 'i':
                                            flag = INetClient.UserFlagType.INROOM;
                                            break;
                                        case 'u':
                                            flag = INetClient.UserFlagType.REGISTERED;
                                            break;
                                        default:
                                            flag = INetClient.UserFlagType.UNKNOWN;
                                            break;
                                        }

                                    for (int j = 1; j < args.length; j++) {
                                        netClient.onUserFlagChange(args[j], flag, mode=='+');
                                    }
                                }
                                break;

                            case CHAT:
                                netClient.onChat(args[0], args[1]);
                                break;

                            case INFO:
                                netClient.onUserInfo(args[0], args[1], args[2], args[3]);
                                break;

                            case PONG:
                                netClient.onPong();
                                break;

                            case NICK:
                                final String newNick = args[0];
                                if (!newNick.equals(this.nick)) {
                                    this.nick = newNick;
                                }
                                    netClient.onNickSet(this.nick);
                                sendMessage(new String[] { "PROTO", PROTOCOL_VERSION });
                                break;

                            case NOTICE:
                                // nickname collision
                                if (args[0].equals("0"))
                                    setNick(netClient.onNickCollision(this.nick));
                                break;

                            case ASKPASSWORD:
                                try {
                                    final String pwHash = netClient.onPasswordHashNeededForAuth();
                                    doAuthPart1(pwHash, args[0]);
                                }
                                catch (Exception ex) {
                                    disconnect();
                                    netClient.onDisconnect(ex.getMessage());
                                }
                                break;

                            case ROOMS:
                                final int nf = ProtocolMessage.ROOM_FIELD_COUNT;
                                for (int a = 0; a < argc; a += nf) {
                                    handleRoomInfo(args[a+1], Arrays.copyOfRange(args, a, a + nf));
                                }

                            case ROOM_ADD:
                                handleRoomInfo(args[1], args);
                                break;

                            case ROOM_DEL:
                                netClient.onRoomDel(args[0]);
                                break;

                            case ROOM_UPD:
                                handleRoomInfo(args[0], Arrays.copyOfRange(args, 1, args.length));
                                break;

                            case BYE:
                                closeConnection();
                                if (argc > 0)
                                    netClient.onDisconnect(args[0]);
                                else
                                    netClient.onDisconnect("");
                                break;

                            case SERVER_AUTH:
                                try {
                                    doAuthPart2(args[0]);
                                }
                                catch (Exception ex) {
                                    disconnect();
                                    netClient.onDisconnect(ex.getMessage());
                                }
                                break;
                        }
                        // end of message
                        return msg;
                    }
                    catch(IllegalArgumentException ex) {

                        netClient.logError("Illegal arguments! "
                            + ProtocolMessage.partsToString(parts)
                            + "caused: " + ex.getMessage());

                        return new ProtocolMessage();
                    }
                }
            }
            else
            {
                parts.add(line);
            }
        }

        netClient.logError("WARNING: Message wasn't parsed correctly: "
                            + ProtocolMessage.partsToString(parts));
        // return invalid message
        return new ProtocolMessage(); // never to be reached
    }

    private void handleRoomInfo(final String name, final String[] info) throws IllegalArgumentException
    {
        // TODO room flags enum array

        final int nUsers;
        final int nTeams;
        
        try {
            nUsers = Integer.parseInt(info[2]);
        }
        catch(IllegalArgumentException ex) {
            throw new IllegalArgumentException(
                "Player count is not an valid integer!",
                ex);
        }

        try {
            nTeams = Integer.parseInt(info[3]);
        }
        catch(IllegalArgumentException ex) {
            throw new IllegalArgumentException(
                "Team count is not an valid integer!",
                ex);
        }

        netClient.onRoomInfo(name, info[0], info[1], nUsers, nTeams,
                             info[4], info[5], info[6], info[7], info[8]);
    }

    private static final String AUTH_SALT = PROTOCOL_VERSION + "!hedgewars";
    private static final int PASSWORD_HASH_LENGTH = 32;
    public static final int SERVER_SALT_MIN_LENGTH = 16;
    private static final String AUTH_ALG = "SHA-1";
    private String serverAuthHash = "";

    private void assertAuthNotIncomplete() throws Exception {
        if (!serverAuthHash.isEmpty()) {
            netClient.logError("AUTH-ERROR: assertAuthNotIncomplete() found that authentication was not completed!");
            throw new Exception("Authentication was not finished properly!");
        }
        serverAuthHash = "";
    }

    private void doAuthPart2(final String serverAuthHash) throws Exception {
        if (!this.serverAuthHash.equals(serverAuthHash)) {
            netClient.logError("AUTH-ERROR: Server's authentication hash is incorrect!");
            throw new Exception("Server failed mutual authentication! (wrong hash provided by server)");
        }
        netClient.logDebug("Auth: Mutual authentication successful.");
        this.serverAuthHash = "";
    }

    private void doAuthPart1(final String pwHash, final String serverSalt) throws Exception {
        if ((pwHash == null) || pwHash.isEmpty()) {
            netClient.logDebug("AUTH: Password required, but no password hash was provided.");
            throw new Exception("Auth: Password needed, but none specified.");
        }
        if (pwHash.length() != PASSWORD_HASH_LENGTH) {
            netClient.logError("AUTH-ERROR: Your password hash has an unexpected length! Should be "
                               + PASSWORD_HASH_LENGTH + " but is " + pwHash.length()
                              );
            throw new Exception("Auth: Your password hash length seems wrong.");
        }
        if (serverSalt.length() < SERVER_SALT_MIN_LENGTH) {
            netClient.logError("AUTH-ERROR: Salt provided by server is too short! Should be at least "
                               + SERVER_SALT_MIN_LENGTH + " but is " + serverSalt.length()
                              );
            throw new Exception("Auth: Server violated authentication protocol! (auth salt too short)");
        }

        final MessageDigest sha1Digest;

        try {
             sha1Digest = MessageDigest.getInstance(AUTH_ALG);
        }
        catch(NoSuchAlgorithmException ex) {
            netClient.logError("AUTH-ERROR: Algorithm required for authentication ("
                                      + AUTH_ALG + ") not available!"
                                     );
            return;
        } 
        

        // generate 130 bit base32 encoded value
        // base32 = 5bits/char => 26 chars, which is more than min req
        final String clientSalt =
            new BigInteger(130, new SecureRandom()).toString(32);

        final String saltedPwHash =
            clientSalt + serverSalt + pwHash + AUTH_SALT;

        final String saltedPwHash2 =
            serverSalt + clientSalt + pwHash + AUTH_SALT;

        final String clientAuthHash =
            new BigInteger(1, sha1Digest.digest(saltedPwHash.getBytes("UTF-8"))).toString(16);

        serverAuthHash =
            new BigInteger(1, sha1Digest.digest(saltedPwHash2.getBytes("UTF-8"))).toString(16);

        sendMessage(new String[] { "PASSWORD", clientAuthHash, clientSalt });

/* When we got password hash, and server asked us for a password, perform mutual authentication:
 * at this point we have salt chosen by server
 * client sends client salt and hash of secret (password hash) salted with client salt, server salt,
 * and static salt (predefined string + protocol number)
 * server should respond with hash of the same set in different order.

    if(m_passwordHash.isEmpty() || m_serverSalt.isEmpty())
        return;

    QString hash = QCryptographicHash::hash(
                m_clientSalt.toAscii()
                .append(m_serverSalt.toAscii())
                .append(m_passwordHash)
                .append(cProtoVer->toAscii())
                .append("!hedgewars")
                , QCryptographicHash::Sha1).toHex();

    m_serverHash = QCryptographicHash::hash(
                m_serverSalt.toAscii()
                .append(m_clientSalt.toAscii())
                .append(m_passwordHash)
                .append(cProtoVer->toAscii())
                .append("!hedgewars")
                , QCryptographicHash::Sha1).toHex();

    RawSendNet(QString("PASSWORD%1%2%1%3").arg(delimiter).arg(hash).arg(m_clientSalt));

Server:  ("ASKPASSWORD", "5S4q9Dd0Qrn1PNsxymtRhupN") 
Client:  ("PASSWORD", "297a2b2f8ef83bcead4056b4df9313c27bb948af", "{cc82f4ca-f73c-469d-9ab7-9661bffeabd1}") 
Server:  ("SERVER_AUTH", "06ecc1cc23b2c9ebd177a110b149b945523752ae") 

 */
    }

    public void sendCommand(final String command)
    {
        String cmd = command;

        // don't execute empty commands
        if (cmd.length() < 1)
            return;

        // replace all newlines since they violate protocol
        cmd = cmd.replace('\n', ' ');

        // parameters are separated by one or more spaces.
        final String[] parts = cmd.split(" +");

        // command is always CAPS
        parts[0] = parts[0].toUpperCase();

        sendMessage(parts);
    }

    public void sendPing()
    {
        sendMessage("PING");
    }

    public void sendPong()
    {
        sendMessage("PONG");
    }

    private void sendMessage(final String msg)
    {
        sendMessage(new String[] { msg });
    }

    private void sendMessage(final String[] parts)
    {
        if (quit)
            return;

        netClient.logDebug("Client: " + messagePartsToString(parts));

        boolean malformed = false;
        String msg = "";

        for (final String part : parts)
        {
            msg += part + '\n';
            if (part.isEmpty() || (part.indexOf('\n') >= 0)) {
                malformed = true;
                break;
            }
        }

        if (malformed) {
            netClient.onMalformedMessage(messagePartsToString(parts));
            return;
        }

        try {
            toSvr.print(msg + '\n'); // don't use println, since we always want '\n'
            toSvr.flush();
        }
        catch(Exception ex) {
            netClient.logError("FATAL: Couldn't send message! " + ex.getMessage());
            ex.printStackTrace();
            handleConnectionLoss();
        }
    }

    private String messagePartsToString(String[] parts) {

        if (parts.length == 0)
            return "([empty message])";

        String result = "(\"" + parts[0] + '"';
        for (int i=1; i < parts.length; i++)
        {
            result += ", \"" + parts[i] + '"';
        }
        result += ')';

        return result;
    }

    public void disconnect() {
        sendMessage(new String[] { "QUIT", "Client quit" });
        closeConnection();
    }

    public void disconnect(final String reason) {
        sendMessage(new String[] { "QUIT", reason.isEmpty()?"-":reason });
        closeConnection();
    }

    public void sendChat(String message) {

        String[] lines = message.split("\n");

        for (String line : lines)
        {
            if (!message.trim().isEmpty())
                sendMessage(new String[] { "CHAT", line });
        }
    }

    public void joinRoom(final String roomName) {

        sendMessage(new String[] { "JOIN_ROOM", roomName });
    }

    public void leaveRoom(final String roomName) {

        sendMessage("PART");
    }

    public void requestInfo(final String user) {

        sendMessage(new String[] { "INFO", user });
    }

    public void setNick(final String nick) {

        this.nick = nick;
        sendMessage(new String[] { "NICK", nick });
    }

    public void kick(final String nick) {

        sendMessage(new String[] { "KICK", nick });
    }

    public void requestRoomsList() {

        sendMessage("LIST");
    }
}