diff -r ee84e417d8d0 -r a855f32ab3ca rust/hedgewars-checker/src/main.rs --- a/rust/hedgewars-checker/src/main.rs Wed Jun 30 00:18:53 2021 +0200 +++ b/rust/hedgewars-checker/src/main.rs Wed Jun 30 23:06:54 2021 +0200 @@ -1,32 +1,17 @@ +use anyhow::{bail, Result}; use argparse::{ArgumentParser, Store}; +use hedgewars_network_protocol::{ + messages::HwProtocolMessage as ClientMessage, messages::HwServerMessage::*, parser, +}; use ini::Ini; +use log::{debug, info, warn}; use netbuf::Buf; -use log::{debug, warn, info}; -use std::{ - io::Write, - net::TcpStream, - process::Command, - str::FromStr -}; - -type CheckError = Box; +use std::{io::Write, net::TcpStream, process::Command, str::FromStr}; -fn extract_packet(buf: &mut Buf) -> Option { - let packet_end = (&buf[..]).windows(2).position(|window| window == b"\n\n")?; - - let mut tail = buf.split_off(packet_end); - - std::mem::swap(&mut tail, buf); - - buf.consume(2); - - Some(tail) -} - -fn check(executable: &str, data_prefix: &str, buffer: &[u8]) -> Result>, CheckError> { +fn check(executable: &str, data_prefix: &str, buffer: &[String]) -> Result> { let mut replay = tempfile::NamedTempFile::new()?; - for line in buffer.split(|b| *b == '\n' as u8) { + for line in buffer.into_iter() { replay.write(&base64::decode(line)?)?; } @@ -57,31 +42,31 @@ loop { match engine_lines.next() { - Some(b"DRAW") => result.push(b"DRAW".to_vec()), + Some(b"DRAW") => result.push("DRAW".to_owned()), Some(b"WINNERS") => { - result.push(b"WINNERS".to_vec()); + result.push("WINNERS".to_owned()); let winners = engine_lines.next().unwrap(); let winners_num = u32::from_str(&String::from_utf8(winners.to_vec())?)?; - result.push(winners.to_vec()); + result.push(String::from_utf8(winners.to_vec())?); for _i in 0..winners_num { - result.push(engine_lines.next().unwrap().to_vec()); + result.push(String::from_utf8(engine_lines.next().unwrap().to_vec())?); } } Some(b"GHOST_POINTS") => { - result.push(b"GHOST_POINTS".to_vec()); + result.push("GHOST_POINTS".to_owned()); let points = engine_lines.next().unwrap(); let points_num = u32::from_str(&String::from_utf8(points.to_vec())?)? * 2; - result.push(points.to_vec()); + result.push(String::from_utf8(points.to_vec())?); for _i in 0..points_num { - result.push(engine_lines.next().unwrap().to_vec()); + result.push(String::from_utf8(engine_lines.next().unwrap().to_vec())?); } } Some(b"ACHIEVEMENT") => { - result.push(b"ACHIEVEMENT".to_vec()); + result.push("ACHIEVEMENT".to_owned()); for _i in 0..4 { - result.push(engine_lines.next().unwrap().to_vec()); + result.push(String::from_utf8(engine_lines.next().unwrap().to_vec())?); } } _ => break, @@ -91,17 +76,17 @@ if result.len() > 0 { Ok(result) } else { - Err("no data from engine".into()) + bail!("no data from engine") } } fn connect_and_run( username: &str, password: &str, - protocol_number: u32, + protocol_number: u16, executable: &str, data_prefix: &str, -) -> Result<(), CheckError> { +) -> Result<()> { info!("Connecting..."); let mut stream = TcpStream::connect("hedgewars.org:46631")?; @@ -112,70 +97,93 @@ loop { buf.read_from(&mut stream)?; - while let Some(msg) = extract_packet(&mut buf) { - if msg[..].starts_with(b"CONNECTED") { - info!("Connected"); - let p = format!( - "CHECKER\n{}\n{}\n{}\n\n", - protocol_number, username, password - ); - stream.write(p.as_bytes())?; - } else if msg[..].starts_with(b"PING") { - stream.write(b"PONG\n\n")?; - } else if msg[..].starts_with(b"LOGONPASSED") { - info!("Logged in"); - stream.write(b"READY\n\n")?; - } else if msg[..].starts_with(b"REPLAY") { - info!("Got a replay"); - match check(executable, data_prefix, &msg[7..]) { - Ok(result) => { - info!("Checked"); - debug!( - "Check result: [{}]", - String::from_utf8_lossy(&result.join(&(',' as u8))) - ); + while let Ok((tail, msg)) = parser::server_message(buf.as_ref()) { + buf.consume(buf.len() - tail.len()); - stream.write(b"CHECKED\nOK\n")?; - stream.write(&result.join(&('\n' as u8)))?; - stream.write(b"\n\nREADY\n\n")?; - } - Err(e) => { - info!("Check failed: {:?}", e); - stream.write(b"CHECKED\nFAIL\nerror\n\nREADY\n\n")?; + match msg { + Connected(_, _) => { + info!("Connected"); + stream.write( + ClientMessage::Checker( + protocol_number, + username.to_owned(), + password.to_owned(), + ) + .to_raw_protocol() + .as_bytes(), + )?; + } + Ping => { + stream.write(ClientMessage::Pong.to_raw_protocol().as_bytes())?; + } + LogonPassed => { + stream.write(ClientMessage::CheckerReady.to_raw_protocol().as_bytes())?; + } + Replay(lines) => { + info!("Got a replay"); + match check(executable, data_prefix, &lines) { + Ok(result) => { + info!("Checked"); + debug!("Check result: [{:?}]", result); + + stream.write( + ClientMessage::CheckedOk(result) + .to_raw_protocol() + .as_bytes(), + )?; + stream + .write(ClientMessage::CheckerReady.to_raw_protocol().as_bytes())?; + } + Err(e) => { + info!("Check failed: {:?}", e); + stream.write( + ClientMessage::CheckedFail("error".to_owned()) + .to_raw_protocol() + .as_bytes(), + )?; + stream + .write(ClientMessage::CheckerReady.to_raw_protocol().as_bytes())?; + } } } - } else if msg[..].starts_with(b"BYE") { - warn!("Received BYE: {}", String::from_utf8_lossy(&msg[..])); - return Ok(()); - } else if msg[..].starts_with(b"CHAT") { - let body = String::from_utf8_lossy(&msg[5..]); - let mut l = body.lines(); - info!("Chat [{}]: {}", l.next().unwrap(), l.next().unwrap()); - } else if msg[..].starts_with(b"ROOM") { - let body = String::from_utf8_lossy(&msg[5..]); - let mut l = body.lines(); - if let Some(action) = l.next() { - if action == "ADD" { - info!("Room added: {}", l.skip(1).next().unwrap()); + Bye(message) => { + warn!("Received BYE: {}", message); + return Ok(()); + } + ChatMsg { nick, msg } => { + info!("Chat [{}]: {}", nick, msg); + } + RoomAdd(fields) => { + let mut l = fields.into_iter(); + info!("Room added: {}", l.skip(1).next().unwrap()); + } + RoomUpdated(name, fields) => { + let mut l = fields.into_iter(); + let new_name = l.skip(1).next().unwrap(); + + if (name != new_name) { + info!("Room renamed: {}", new_name); } } - } else if msg[..].starts_with(b"ERROR") { - warn!("Received ERROR: {}", String::from_utf8_lossy(&msg[..])); - return Ok(()); - } else { - warn!( - "Unknown protocol command: {}", - String::from_utf8_lossy(&msg[..]) - ) + RoomRemove(_) => { + // ignore + } + Error(message) => { + warn!("Received ERROR: {}", message); + return Ok(()); + } + something => { + warn!("Unexpected protocol command: {:?}", something) + } } } } } -fn get_protocol_number(executable: &str) -> std::io::Result { +fn get_protocol_number(executable: &str) -> std::io::Result { let output = Command::new(executable).arg("--protocol").output()?; - Ok(u32::from_str(&String::from_utf8(output.stdout).unwrap().trim()).unwrap_or(55)) + Ok(u16::from_str(&String::from_utf8(output.stdout).unwrap().trim()).unwrap_or(55)) } fn main() { @@ -214,23 +222,3 @@ connect_and_run(&username, &password, protocol_number, &exe, &prefix).unwrap(); } - -#[cfg(test)] -#[test] -fn test() { - let mut buf = Buf::new(); - buf.extend(b"Hell"); - if let Some(_) = extract_packet(&mut buf) { - assert!(false) - } - - buf.extend(b"o\n\nWorld"); - - let packet2 = extract_packet(&mut buf).unwrap(); - assert_eq!(&buf[..], b"World"); - assert_eq!(&packet2[..], b"Hello"); - - if let Some(_) = extract_packet(&mut buf) { - assert!(false) - } -}