254 QueueBindOptions::default(), |
254 QueueBindOptions::default(), |
255 FieldTable::default(), |
255 FieldTable::default(), |
256 ) |
256 ) |
257 .await?; |
257 .await?; |
258 |
258 |
|
259 sub_channel |
|
260 .queue_bind( |
|
261 queue.name().as_str(), |
|
262 "irc", |
|
263 "msg.hedgewars", |
|
264 QueueBindOptions::default(), |
|
265 FieldTable::default(), |
|
266 ) |
|
267 .await?; |
|
268 |
259 let mut subscriber = sub_channel |
269 let mut subscriber = sub_channel |
260 .basic_consume( |
270 .basic_consume( |
261 queue.name().as_str(), |
271 queue.name().as_str(), |
262 &"", |
272 &"", |
263 BasicConsumeOptions::default(), |
273 BasicConsumeOptions::default(), |
265 ) |
275 ) |
266 .await?; |
276 .await?; |
267 |
277 |
268 vs.load(args[2].as_str())?; |
278 vs.load(args[2].as_str())?; |
269 |
279 |
|
280 let mut buffer = Vec::new(); |
|
281 |
270 while let Some(amqp_message) = subscriber.next().await { |
282 while let Some(amqp_message) = subscriber.next().await { |
271 let (_, delivery) = amqp_message.expect("error in consumer"); |
283 let (_, delivery) = amqp_message.expect("error in consumer"); |
272 delivery.ack(BasicAckOptions::default()).await?; |
284 delivery.ack(BasicAckOptions::default()).await?; |
273 |
285 |
274 let chat_message = String::from_utf8(delivery.data)?; |
286 if delivery.routing_key.as_str() == "msg.hedgewars" { |
275 if let Some((_who, seed)) = chat_message.split_once('\n') { |
287 let chat_message = String::from_utf8_lossy(&delivery.data); |
276 let input_sample = &format!("\n{}", seed); |
288 if let Some((_who, message)) = chat_message.split_once('\n') { |
|
289 buffer.push('\n'); |
|
290 buffer.extend(message.chars()); |
|
291 if buffer.len() >= BLOCK_SIZE as usize { |
|
292 let _ = buffer.drain(0..=buffer.len() - BLOCK_SIZE as usize); |
|
293 } |
|
294 } |
|
295 } else { |
|
296 let chat_message = String::from_utf8_lossy(&delivery.data); |
|
297 let seed = chat_message.split_once('\n').map(|(_, s)| s).unwrap_or(""); |
|
298 buffer.push('\n'); |
|
299 buffer.extend(seed.chars()); |
|
300 |
|
301 if buffer.len() >= BLOCK_SIZE as usize { |
|
302 let _ = buffer.drain(0..=buffer.len() - BLOCK_SIZE as usize); |
|
303 } |
|
304 |
277 let input = Tensor::zeros(&[1, BLOCK_SIZE], (Kind::Int64, device)); |
305 let input = Tensor::zeros(&[1, BLOCK_SIZE], (Kind::Int64, device)); |
278 for (idx, c) in input_sample.chars().rev().enumerate() { |
306 for (idx, c) in buffer.iter().rev().enumerate() { |
279 let idx = idx as i64; |
|
280 if idx >= BLOCK_SIZE { |
|
281 break; |
|
282 } |
|
283 let _filled = input |
307 let _filled = input |
284 .i((0, BLOCK_SIZE - 1 - idx)) |
308 .i((0, BLOCK_SIZE - 1 - idx as i64)) |
285 .fill_(data.char_to_label(c).unwrap_or(0) as i64); |
309 .fill_(data.char_to_label(*c).unwrap_or(0) as i64); |
286 } |
310 } |
287 |
311 |
288 let proceeded_message = &sample(&data, &gpt, input); |
312 let proceeded_message = &sample(&data, &gpt, input); |
289 let final_message = proceeded_message |
313 let final_message = proceeded_message |
290 .split_once('\n') |
314 .split_once('\n') |