# HG changeset patch # User unc0rr # Date 1623782788 -7200 # Node ID 96443d9b48c9a82240e6aa4b778e5c796009f7e7 # Parent 191e51179d1b38fcbaca88d849d459223315dc9c Update mingpt plugin to include recent chat history in the seed diff -r 191e51179d1b -r 96443d9b48c9 tools/ubot-plugins/ubot-mingpt-plugin/src/main.rs --- a/tools/ubot-plugins/ubot-mingpt-plugin/src/main.rs Tue Jun 15 20:45:46 2021 +0200 +++ b/tools/ubot-plugins/ubot-mingpt-plugin/src/main.rs Tue Jun 15 20:46:28 2021 +0200 @@ -256,6 +256,16 @@ ) .await?; + sub_channel + .queue_bind( + queue.name().as_str(), + "irc", + "msg.hedgewars", + QueueBindOptions::default(), + FieldTable::default(), + ) + .await?; + let mut subscriber = sub_channel .basic_consume( queue.name().as_str(), @@ -267,22 +277,36 @@ vs.load(args[2].as_str())?; + let mut buffer = Vec::new(); + while let Some(amqp_message) = subscriber.next().await { let (_, delivery) = amqp_message.expect("error in consumer"); delivery.ack(BasicAckOptions::default()).await?; - let chat_message = String::from_utf8(delivery.data)?; - if let Some((_who, seed)) = chat_message.split_once('\n') { - let input_sample = &format!("\n{}", seed); + if delivery.routing_key.as_str() == "msg.hedgewars" { + let chat_message = String::from_utf8_lossy(&delivery.data); + if let Some((_who, message)) = chat_message.split_once('\n') { + buffer.push('\n'); + buffer.extend(message.chars()); + if buffer.len() >= BLOCK_SIZE as usize { + let _ = buffer.drain(0..=buffer.len() - BLOCK_SIZE as usize); + } + } + } else { + let chat_message = String::from_utf8_lossy(&delivery.data); + let seed = chat_message.split_once('\n').map(|(_, s)| s).unwrap_or(""); + buffer.push('\n'); + buffer.extend(seed.chars()); + + if buffer.len() >= BLOCK_SIZE as usize { + let _ = buffer.drain(0..=buffer.len() - BLOCK_SIZE as usize); + } + let input = Tensor::zeros(&[1, BLOCK_SIZE], (Kind::Int64, device)); - for (idx, c) in input_sample.chars().rev().enumerate() { - let idx = idx as i64; - if idx >= BLOCK_SIZE { - break; - } + for (idx, c) in buffer.iter().rev().enumerate() { let _filled = input - .i((0, BLOCK_SIZE - 1 - idx)) - .fill_(data.char_to_label(c).unwrap_or(0) as i64); + .i((0, BLOCK_SIZE - 1 - idx as i64)) + .fill_(data.char_to_label(*c).unwrap_or(0) as i64); } let proceeded_message = &sample(&data, &gpt, input);