project_files/frontlib/net/netbase.c
author unC0Rr
Tue, 05 Sep 2023 17:01:46 +0200
branchtransitional_engine
changeset 16036 96d0e6149d3d
parent 10017 de822cd3df3a
permissions -rw-r--r--
Fix type

/*
 * Hedgewars, a free turn based strategy game
 * Copyright (C) 2012 Simeon Maxein <smaxein@googlemail.com>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
 */

#include "netbase.h"
#include "../util/buffer.h"
#include "../util/logging.h"
#include "../util/util.h"
#include "../socket.h"

#include <string.h>
#include <stdbool.h>
#include <stdlib.h>
#include <stdio.h>

/*
 * Upper bound (in bytes) for buffered data that has not yet been parsed into a
 * message. receiveToBuffer() closes the connection once the read buffer has
 * grown beyond this size.
 */
#define NET_READBUFFER_LIMIT (1024*1024)

/**
 * State of one net connection.
 */
struct _flib_netbase {
    flib_vector *readBuffer;    // Bytes received from the socket that have not been parsed into a message yet
    flib_tcpsocket *sock;       // Connection to the server; reset to NULL when the connection is closed after an error
};

/**
 * Create a net connection object and connect it to server:port.
 *
 * Returns the connected object on success. Returns NULL if server is NULL,
 * port is 0, the object could not be allocated, or either the read buffer or
 * the socket connection could not be set up. On a setup failure the partially
 * built object is handed to flib_netbase_destroy.
 */
flib_netbase *flib_netbase_create(const char *server, uint16_t port) {
    if(log_badargs_if2(server==NULL, port==0)) {
        return NULL;
    }

    flib_netbase *net = flib_calloc(1, sizeof(flib_netbase));
    if(!net) {
        return NULL;
    }

    // Both members are always set up before checking, so a failed buffer
    // allocation does not skip the connection attempt.
    net->readBuffer = flib_vector_create();
    net->sock = flib_socket_connect(server, port);
    if(!net->readBuffer || !net->sock) {
        flib_netbase_destroy(net);
        return NULL;
    }

    flib_log_i("Connected to server %s:%u", server, (unsigned)port);
    return net;
}

/**
 * Close the connection's socket, destroy its read buffer and free the object.
 * Passing NULL is allowed and does nothing.
 */
void flib_netbase_destroy(flib_netbase *net) {
    if(!net) {
        return;
    }
    flib_socket_close(net->sock);
    flib_vector_destroy(net->readBuffer);
    free(net);
}

/**
 * Report whether the connection still has an open socket.
 * Returns false (after logging the bad argument) if net is NULL.
 */
bool flib_netbase_connected(flib_netbase *net) {
    if(log_badargs_if(net==NULL)) {
        return false;
    }
    return net->sock != NULL;
}

/**
 * Try to parse one complete message from the front of the vector.
 *
 * A message is a sequence of '\n'-terminated parts, ended by an empty part
 * (i.e. two consecutive '\n', or a single leading '\n' for a message without
 * parts). On success the message's bytes are removed from the vector and the
 * parsed message is returned. Returns NULL, leaving the vector untouched, if
 * the vector does not hold a complete message yet or if building the message
 * fails.
 */
static flib_netmsg *parseMessage(flib_vector *vec) {
    const uint8_t *cursor = flib_vector_data(vec);
    const uint8_t *bufEnd = cursor+flib_vector_size(vec);

    flib_netmsg *msg = flib_netmsg_create();
    if(!msg) {
        return NULL;
    }

    for(;;) {
        const uint8_t *newline = memchr(cursor, '\n', bufEnd-cursor);
        if(!newline) {
            // No terminator for the current part: the message is incomplete.
            break;
        }

        if(newline == cursor) {
            // Empty part marks the end of the message. Shift the remaining
            // bytes to the front of the vector and shrink it accordingly.
            uint8_t *bufStart = flib_vector_data(vec);
            size_t consumed = newline+1-bufStart;
            size_t remaining = flib_vector_size(vec)-consumed;
            memmove(bufStart, newline+1, remaining);
            flib_vector_resize(vec, remaining);
            return msg;
        }

        if(flib_netmsg_append_part(msg, cursor, newline-cursor)) {
            // Could not store the part; give up on this message.
            break;
        }
        cursor = newline+1; // Continue after the '\n'
    }

    flib_netmsg_destroy(msg);
    return NULL;
}

/**
 * Read up to 256 bytes from the socket and append them to the read buffer.
 *
 * Returns the number of bytes that were received and buffered, or 0 if
 * nothing was added. If the read buffer already exceeds
 * NET_READBUFFER_LIMIT, if flib_socket_nbrecv reports a negative size or if
 * appending to the buffer fails, the socket is closed and net->sock is set
 * to NULL. Does nothing and returns 0 if the socket is already NULL.
 */
static int receiveToBuffer(flib_netbase *net) {
    if(!net->sock) {
        return 0;
    }

    if(flib_vector_size(net->readBuffer) > NET_READBUFFER_LIMIT) {
        flib_log_e("Net connection closed: Net message too big");
    } else {
        uint8_t chunk[256];
        int received = flib_socket_nbrecv(net->sock, chunk, sizeof(chunk));
        if(received>=0 && !flib_vector_append(net->readBuffer, chunk, received)) {
            return received;
        }
    }

    // Every failure ends up here: drop the connection.
    flib_socket_close(net->sock);
    net->sock = NULL;
    return 0;
}

/**
 * Return the next complete message from the connection, or NULL if none is
 * available (or net is NULL).
 *
 * Data already in the read buffer is parsed first; while no complete message
 * is found, more bytes are pulled from the socket for as long as
 * receiveToBuffer reports that it added any. If the connection turns out to
 * be closed and only an incomplete message is left, the leftover bytes are
 * discarded.
 */
flib_netmsg *flib_netbase_recv_message(flib_netbase *net) {
    if(log_badargs_if(net==NULL)) {
        return NULL;
    }

    flib_netmsg *parsed = parseMessage(net->readBuffer);
    while(!parsed && receiveToBuffer(net)) {
        parsed = parseMessage(net->readBuffer);
    }
    if(parsed) {
        return parsed;
    }

    if(!net->sock && flib_vector_size(net->readBuffer)>0) {
        // The connection is gone, so the partial message can never be completed.
        flib_vector_resize(net->readBuffer, 0);
    }
    return NULL;
}

/**
 * Write an outgoing chunk of data to the debug log, prefixed with its length.
 * Nothing is formatted unless the debug log level is active.
 *
 * data does not need to be NUL-terminated: the precision limits how many
 * bytes "%s" reads.
 */
static void logSentMsg(const uint8_t *data, size_t len) {
    if(flib_log_isActive(FLIB_LOGLEVEL_DEBUG)) {
        // The '*' width and precision arguments of a printf-style format must
        // be of type int, and "%s" expects a pointer to char.
        flib_log_d("[NET OUT][%03u]%*.*s", (unsigned)len, (int)len, (int)len, (const char*)data);
    }
}

/**
 * Send len bytes from data over the connection without any framing.
 *
 * Returns 0 if all bytes were sent. Returns -1 on bad arguments (net is NULL,
 * or data is NULL while len>0), if the connection is not open, or if
 * flib_socket_send did not report exactly len bytes; in the last case the
 * socket is closed and net->sock is set to NULL.
 */
int flib_netbase_send_raw(flib_netbase *net, const void *data, size_t len) {
    if(log_badargs_if2(net==NULL, data==NULL && len>0)) {
        return -1;
    }

    if(net->sock == NULL) {
        flib_log_w("flib_netbase_send_raw: Not connected.");
        return -1;
    }

    if(flib_socket_send(net->sock, data, len) != len) {
        // Anything other than a complete write is treated as a dead connection.
        flib_log_w("Failed or incomplete write: net connection lost.");
        flib_socket_close(net->sock);
        net->sock = NULL;
        return -1;
    }

    logSentMsg(data, len);
    return 0;
}

/**
 * Serialize a message and send it over the connection.
 *
 * On the wire every part is followed by a '\n', and the whole message is
 * terminated by one additional '\n'.
 *
 * Returns 0 on success, -1 if an argument is NULL, if the temporary send
 * buffer cannot be allocated, or if flib_netbase_send_raw fails.
 */
int flib_netbase_send_message(flib_netbase *net, const flib_netmsg *msg) {
    if(log_badargs_if2(net==NULL, msg==NULL)) {
        return -1;
    }

    // One byte per part for its '\n', plus one for the terminating '\n'.
    size_t totalSize = 1;
    for(int i=0; i<msg->partCount; i++) {
        totalSize += strlen(msg->parts[i]) + 1;
    }

    uint8_t *buffer = flib_malloc(totalSize);
    if(!buffer) {
        return -1;
    }

    size_t pos = 0;
    for(int i=0; i<msg->partCount; i++) {
        size_t partsize = strlen(msg->parts[i]);
        memcpy(buffer+pos, msg->parts[i], partsize);
        pos += partsize;
        buffer[pos++] = '\n';
    }
    buffer[pos++] = '\n';

    int result = flib_netbase_send_raw(net, buffer, pos);
    // The buffer is only needed for the duration of the send; release it on
    // both the success and the failure path.
    free(buffer);
    return result;
}

/**
 * Format a string printf-style and send it over the connection as raw data
 * (without the terminating NUL byte).
 *
 * Returns the result of flib_netbase_send_raw, or -1 if net or format is
 * NULL or flib_vasprintf returns NULL.
 */
int flib_netbase_sendf(flib_netbase *net, const char *format, ...) {
    if(log_badargs_if2(net==NULL, format==NULL)) {
        return -1;
    }

    va_list args;
    va_start(args, format);
    char *text = flib_vasprintf(format, args);
    va_end(args);

    int status = -1;
    if(text) {
        status = flib_netbase_send_raw(net, text, strlen(text));
    }
    free(text);
    return status;
}

/**
 * Allocate a new message without any parts.
 * Returns NULL if the allocation fails.
 */
flib_netmsg *flib_netmsg_create() {
    flib_netmsg *msg = flib_calloc(1, sizeof(flib_netmsg));
    if(!msg) {
        return NULL;
    }
    msg->parts = NULL;
    msg->partCount = 0;
    return msg;
}

/**
 * Free a message together with all of its parts.
 * Passing NULL is allowed and does nothing.
 */
void flib_netmsg_destroy(flib_netmsg *msg) {
    if(!msg) {
        return;
    }

    char **parts = msg->parts;
    for(int idx=0; idx<msg->partCount; idx++) {
        free(parts[idx]);
    }
    free(parts);
    free(msg);
}

/**
 * Append a copy of a part to the message.
 *
 * The partlen bytes at part are copied into a newly allocated,
 * NUL-terminated string that is stored as the new last part. part may be
 * NULL only if partlen is 0, which appends an empty part.
 *
 * Returns 0 on success. Returns -1 on bad arguments or if memory cannot be
 * allocated; in that case the message still has its previous parts.
 */
int flib_netmsg_append_part(flib_netmsg *msg, const void *part, size_t partlen) {
    if(log_badargs_if2(msg==NULL, part==NULL && partlen>0)) {
        return -1;
    }

    // Grow the part array first. On failure msg->parts is left untouched.
    char **grown = realloc(msg->parts, (msg->partCount+1)*sizeof(*msg->parts));
    if(!grown) {
        return -1;
    }
    msg->parts = grown;

    char *copy = flib_malloc(partlen+1);
    if(!copy) {
        // The array stays enlarged, but partCount is unchanged, so the
        // unused slot is never read.
        return -1;
    }

    // part may be NULL for an empty part, and memcpy must not be called with
    // a NULL pointer even when the length is 0.
    if(partlen>0) {
        memcpy(copy, part, partlen);
    }
    copy[partlen] = 0;

    msg->parts[msg->partCount] = copy;
    msg->partCount++;
    return 0;
}