tools/ubot-plugins/ubot-mingpt-plugin/src/main.rs
changeset 15793 96443d9b48c9
parent 15791 2528e3508bf4
equal deleted inserted replaced
15792:191e51179d1b 15793:96443d9b48c9
   254                     QueueBindOptions::default(),
   254                     QueueBindOptions::default(),
   255                     FieldTable::default(),
   255                     FieldTable::default(),
   256                 )
   256                 )
   257                 .await?;
   257                 .await?;
   258 
   258 
       
   259             sub_channel
       
   260                 .queue_bind(
       
   261                     queue.name().as_str(),
       
   262                     "irc",
       
   263                     "msg.hedgewars",
       
   264                     QueueBindOptions::default(),
       
   265                     FieldTable::default(),
       
   266                 )
       
   267                 .await?;
       
   268 
   259             let mut subscriber = sub_channel
   269             let mut subscriber = sub_channel
   260                 .basic_consume(
   270                 .basic_consume(
   261                     queue.name().as_str(),
   271                     queue.name().as_str(),
   262                     &"",
   272                     &"",
   263                     BasicConsumeOptions::default(),
   273                     BasicConsumeOptions::default(),
   265                 )
   275                 )
   266                 .await?;
   276                 .await?;
   267 
   277 
   268             vs.load(args[2].as_str())?;
   278             vs.load(args[2].as_str())?;
   269 
   279 
       
   280             let mut buffer = Vec::new();
       
   281 
   270             while let Some(amqp_message) = subscriber.next().await {
   282             while let Some(amqp_message) = subscriber.next().await {
   271                 let (_, delivery) = amqp_message.expect("error in consumer");
   283                 let (_, delivery) = amqp_message.expect("error in consumer");
   272                 delivery.ack(BasicAckOptions::default()).await?;
   284                 delivery.ack(BasicAckOptions::default()).await?;
   273 
   285 
   274                 let chat_message = String::from_utf8(delivery.data)?;
   286                 if delivery.routing_key.as_str() == "msg.hedgewars" {
   275                 if let Some((_who, seed)) = chat_message.split_once('\n') {
   287                     let chat_message = String::from_utf8_lossy(&delivery.data);
   276                     let input_sample = &format!("\n{}", seed);
   288                     if let Some((_who, message)) = chat_message.split_once('\n') {
       
   289                         buffer.push('\n');
       
   290                         buffer.extend(message.chars());
       
   291                         if buffer.len() >= BLOCK_SIZE as usize {
       
   292                             let _ = buffer.drain(0..=buffer.len() - BLOCK_SIZE as usize);
       
   293                         }
       
   294                     }
       
   295                 } else {
       
   296                     let chat_message = String::from_utf8_lossy(&delivery.data);
       
   297                     let seed = chat_message.split_once('\n').map(|(_, s)| s).unwrap_or("");
       
   298                     buffer.push('\n');
       
   299                     buffer.extend(seed.chars());
       
   300 
       
   301                     if buffer.len() >= BLOCK_SIZE as usize {
       
   302                         let _ = buffer.drain(0..=buffer.len() - BLOCK_SIZE as usize);
       
   303                     }
       
   304 
   277                     let input = Tensor::zeros(&[1, BLOCK_SIZE], (Kind::Int64, device));
   305                     let input = Tensor::zeros(&[1, BLOCK_SIZE], (Kind::Int64, device));
   278                     for (idx, c) in input_sample.chars().rev().enumerate() {
   306                     for (idx, c) in buffer.iter().rev().enumerate() {
   279                         let idx = idx as i64;
       
   280                         if idx >= BLOCK_SIZE {
       
   281                             break;
       
   282                         }
       
   283                         let _filled = input
   307                         let _filled = input
   284                             .i((0, BLOCK_SIZE - 1 - idx))
   308                             .i((0, BLOCK_SIZE - 1 - idx as i64))
   285                             .fill_(data.char_to_label(c).unwrap_or(0) as i64);
   309                             .fill_(data.char_to_label(*c).unwrap_or(0) as i64);
   286                     }
   310                     }
   287 
   311 
   288                     let proceeded_message = &sample(&data, &gpt, input);
   312                     let proceeded_message = &sample(&data, &gpt, input);
   289                     let final_message = proceeded_message
   313                     let final_message = proceeded_message
   290                         .split_once('\n')
   314                         .split_once('\n')