tools/ubot-plugins/ubot-mingpt-plugin/src/main.rs
changeset 15791 2528e3508bf4
child 15793 96443d9b48c9
equal deleted inserted replaced
15790:efe4e3290870 15791:2528e3508bf4
       
     1 /* This example uses the tinyshakespeare dataset which can be downloaded at:
       
     2    https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
       
     3 
       
     4    This is mostly a rust port of https://github.com/karpathy/minGPT
       
     5 */
       
     6 
       
     7 extern crate tch;
       
     8 use anyhow::{bail, Result as AHResult};
       
     9 use std::{io, io::Write};
       
    10 use tch::data::TextData;
       
    11 use tch::nn::{ModuleT, OptimizerConfig};
       
    12 use tch::{nn, Device, IndexOp, Kind, Tensor};
       
    13 
       
    14 use futures::prelude::*;
       
    15 use lapin::{options::*, types::FieldTable, BasicProperties, Connection, ConnectionProperties};
       
    16 
       
    17 use tokio_amqp::*;
       
    18 
       
/// Learning rate handed to the AdamW optimizer when training.
const LEARNING_RATE: f64 = 0.0003;
/// Context window length in tokens: used as the model's `block_size`, as the
/// width of every training input/target slice, and as the width of the
/// sliding window kept while sampling.
const BLOCK_SIZE: i64 = 128;
/// Number of sequences requested per training batch.
const BATCH_SIZE: i64 = 64;
/// Number of passes over the shuffled dataset performed by the training loop.
const EPOCHS: i64 = 100;
/// Number of characters generated by one call to `sample`.
const SAMPLING_LEN: i64 = 512;
       
    24 
       
/// Hyper-parameters shared by every layer of the GPT model.
#[derive(Debug, Copy, Clone)]
struct Config {
    /// Number of distinct labels; sizes the token embedding table and the
    /// output of the final `head` projection.
    vocab_size: i64,
    /// Width of the embeddings and of every attention/MLP input and output.
    n_embd: i64,
    /// Number of attention heads; the channel dimension is divided by this
    /// value to obtain the per-head width.
    n_head: i64,
    /// Number of transformer blocks stacked in the model.
    n_layer: i64,
    /// Maximum sequence length: side of the causal mask and length of the
    /// positional embedding.
    block_size: i64,
    /// Dropout probability applied to the attention weights after softmax.
    attn_pdrop: f64,
    /// Dropout probability applied to the attention projection output and to
    /// the MLP output of each block.
    resid_pdrop: f64,
    /// Dropout probability applied to the sum of token and positional
    /// embeddings.
    embd_pdrop: f64,
}
       
    36 
       
// Weight decay only applies to the weight matrices in the linear layers.
// Variables are assigned to one of these two optimizer groups when they are
// created; `main` sets the decay to 0.0 for the first group and 0.1 for the
// second.
const NO_WEIGHT_DECAY_GROUP: usize = 0;
const WEIGHT_DECAY_GROUP: usize = 1;
       
    40 
       
/// Custom linear layer so that different optimizer groups can be used for
/// the weight and the bias.
#[derive(Debug)]
struct Linear {
    /// Weight matrix, created with shape `[out_dim, in_dim]`.
    pub ws: Tensor,
    /// Bias vector, created with shape `[out_dim]`.
    pub bs: Tensor,
}
       
    48 
       
    49 impl nn::Module for Linear {
       
    50     fn forward(&self, xs: &Tensor) -> Tensor {
       
    51         xs.matmul(&self.ws.tr()) + &self.bs
       
    52     }
       
    53 }
       
    54 
       
    55 fn linear(vs: nn::Path, in_dim: i64, out_dim: i64) -> Linear {
       
    56     let wd = vs.set_group(WEIGHT_DECAY_GROUP);
       
    57     let no_wd = vs.set_group(NO_WEIGHT_DECAY_GROUP);
       
    58     Linear {
       
    59         ws: wd.randn("weight", &[out_dim, in_dim], 0.0, 0.02),
       
    60         bs: no_wd.zeros("bias", &[out_dim]),
       
    61     }
       
    62 }
       
    63 
       
    64 fn linear_no_bias(vs: nn::Path, in_dim: i64, out_dim: i64) -> Linear {
       
    65     let wd = vs.set_group(WEIGHT_DECAY_GROUP);
       
    66     let no_wd = vs.set_group(NO_WEIGHT_DECAY_GROUP);
       
    67     Linear {
       
    68         ws: wd.randn("weight", &[out_dim, in_dim], 0.0, 0.02),
       
    69         bs: no_wd.zeros_no_train("bias", &[out_dim]),
       
    70     }
       
    71 }
       
    72 
       
    73 fn causal_self_attention(p: &nn::Path, cfg: Config) -> impl ModuleT {
       
    74     let key = linear(p / "key", cfg.n_embd, cfg.n_embd);
       
    75     let query = linear(p / "query", cfg.n_embd, cfg.n_embd);
       
    76     let value = linear(p / "value", cfg.n_embd, cfg.n_embd);
       
    77     let proj = linear(p / "proj", cfg.n_embd, cfg.n_embd);
       
    78     let mask_init =
       
    79         Tensor::ones(&[cfg.block_size, cfg.block_size], (Kind::Float, p.device())).tril(0);
       
    80     let mask_init = mask_init.view([1, 1, cfg.block_size, cfg.block_size]);
       
    81     // let mask = p.var_copy("mask", &mask_init);
       
    82     let mask = mask_init;
       
    83     nn::func_t(move |xs, train| {
       
    84         let (sz_b, sz_t, sz_c) = xs.size3().unwrap();
       
    85         let sizes = [sz_b, sz_t, cfg.n_head, sz_c / cfg.n_head];
       
    86         let k = xs.apply(&key).view(sizes).transpose(1, 2);
       
    87         let q = xs.apply(&query).view(sizes).transpose(1, 2);
       
    88         let v = xs.apply(&value).view(sizes).transpose(1, 2);
       
    89         let att = q.matmul(&k.transpose(-2, -1)) * (1.0 / f64::sqrt(sizes[3] as f64));
       
    90         let att = att.masked_fill(
       
    91             &mask.i((.., .., ..sz_t, ..sz_t)).eq(0.),
       
    92             std::f64::NEG_INFINITY,
       
    93         );
       
    94         let att = att.softmax(-1, Kind::Float).dropout(cfg.attn_pdrop, train);
       
    95         let ys = att
       
    96             .matmul(&v)
       
    97             .transpose(1, 2)
       
    98             .contiguous()
       
    99             .view([sz_b, sz_t, sz_c]);
       
   100         ys.apply(&proj).dropout(cfg.resid_pdrop, train)
       
   101     })
       
   102 }
       
   103 
       
   104 fn block(p: &nn::Path, cfg: Config) -> impl ModuleT {
       
   105     let ln1 = nn::layer_norm(p / "ln1", vec![cfg.n_embd], Default::default());
       
   106     let ln2 = nn::layer_norm(p / "ln2", vec![cfg.n_embd], Default::default());
       
   107     let attn = causal_self_attention(p, cfg);
       
   108     let lin1 = linear(p / "lin1", cfg.n_embd, 4 * cfg.n_embd);
       
   109     let lin2 = linear(p / "lin2", 4 * cfg.n_embd, cfg.n_embd);
       
   110     nn::func_t(move |xs, train| {
       
   111         let xs = xs + xs.apply(&ln1).apply_t(&attn, train);
       
   112         let ys = xs
       
   113             .apply(&ln2)
       
   114             .apply(&lin1)
       
   115             .gelu()
       
   116             .apply(&lin2)
       
   117             .dropout(cfg.resid_pdrop, train);
       
   118         xs + ys
       
   119     })
       
   120 }
       
   121 
       
   122 fn gpt(p: &nn::Path, cfg: Config) -> impl ModuleT {
       
   123     let p = &p.set_group(NO_WEIGHT_DECAY_GROUP);
       
   124     let tok_emb = nn::embedding(
       
   125         p / "tok_emb",
       
   126         cfg.vocab_size,
       
   127         cfg.n_embd,
       
   128         Default::default(),
       
   129     );
       
   130     let pos_emb = p.zeros("pos_emb", &[1, cfg.block_size, cfg.n_embd]);
       
   131     let ln_f = nn::layer_norm(p / "ln_f", vec![cfg.n_embd], Default::default());
       
   132     let head = linear_no_bias(p / "head", cfg.n_embd, cfg.vocab_size);
       
   133     let mut blocks = nn::seq_t();
       
   134     for block_idx in 0..cfg.n_layer {
       
   135         blocks = blocks.add(block(&(p / block_idx), cfg));
       
   136     }
       
   137     nn::func_t(move |xs, train| {
       
   138         let (_sz_b, sz_t) = xs.size2().unwrap();
       
   139         let tok_emb = xs.apply(&tok_emb);
       
   140         let pos_emb = pos_emb.i((.., ..sz_t, ..));
       
   141         (tok_emb + pos_emb)
       
   142             .dropout(cfg.embd_pdrop, train)
       
   143             .apply_t(&blocks, train)
       
   144             .apply(&ln_f)
       
   145             .apply(&head)
       
   146     })
       
   147 }
       
   148 
       
   149 /// Generates some sample string using the GPT model.
       
   150 fn sample(data: &TextData, gpt: &impl ModuleT, input: Tensor) -> String {
       
   151     let mut input = input;
       
   152     let mut result = String::new();
       
   153     for _index in 0..SAMPLING_LEN {
       
   154         let logits = input.apply_t(gpt, false).i((0, -1, ..));
       
   155         let sampled_y = logits.softmax(-1, Kind::Float).multinomial(1, true);
       
   156         let last_label = i64::from(&sampled_y);
       
   157         result.push(data.label_to_char(last_label));
       
   158         input = Tensor::cat(&[input, sampled_y.view([1, 1])], 1).narrow(1, 1, BLOCK_SIZE);
       
   159     }
       
   160     result
       
   161 }
       
   162 
       
   163 #[tokio::main]
       
   164 async fn main() -> AHResult<()> {
       
   165     let device = Device::cuda_if_available();
       
   166     let mut vs = nn::VarStore::new(device);
       
   167     let data = TextData::new("10.log")?;
       
   168     let labels = data.labels();
       
   169     println!("Dataset loaded, {} labels.", labels);
       
   170     let cfg = Config {
       
   171         vocab_size: labels,
       
   172         n_embd: 384, // was 512
       
   173         n_head: 8,
       
   174         n_layer: 8,
       
   175         block_size: BLOCK_SIZE,
       
   176         attn_pdrop: 0.1,
       
   177         resid_pdrop: 0.1,
       
   178         embd_pdrop: 0.1,
       
   179     };
       
   180     let gpt = gpt(&(&vs.root() / "gpt"), cfg);
       
   181     let args: Vec<_> = std::env::args().collect();
       
   182     if args.len() < 2 {
       
   183         bail!("usage: main (train|predict weights.ot seqstart)")
       
   184     }
       
   185     match args[1].as_str() {
       
   186         "train" => {
       
   187             let mut opt = nn::AdamW::default().build(&vs, LEARNING_RATE)?;
       
   188             opt.set_weight_decay_group(NO_WEIGHT_DECAY_GROUP, 0.0);
       
   189             opt.set_weight_decay_group(WEIGHT_DECAY_GROUP, 0.1);
       
   190             let mut idx = 0;
       
   191             vs.load("384.ot")?;
       
   192             for epoch in 1..(1 + EPOCHS) {
       
   193                 let mut sum_loss = 0.;
       
   194                 let mut cnt_loss = 0.;
       
   195                 for batch in data.iter_shuffle(BLOCK_SIZE + 1, BATCH_SIZE) {
       
   196                     let xs = batch
       
   197                         .narrow(1, 0, BLOCK_SIZE)
       
   198                         .to_kind(Kind::Int64)
       
   199                         .to_device(device);
       
   200                     let ys = batch
       
   201                         .narrow(1, 1, BLOCK_SIZE)
       
   202                         .to_kind(Kind::Int64)
       
   203                         .to_device(device);
       
   204                     let logits = xs.apply_t(&gpt, true);
       
   205                     let loss = logits
       
   206                         .view([BATCH_SIZE * BLOCK_SIZE, labels])
       
   207                         .cross_entropy_for_logits(&ys.view([BATCH_SIZE * BLOCK_SIZE]));
       
   208                     opt.backward_step_clip(&loss, 0.5);
       
   209                     sum_loss += f64::from(loss);
       
   210                     cnt_loss += 1.0;
       
   211                     idx += 1;
       
   212                     if idx % 10 == 0 {
       
   213                         print!("{}", '.');
       
   214                         io::stdout().flush()?;
       
   215                     }
       
   216                     if idx % 1000 == 0 {
       
   217                         println!("Epoch: {}   loss: {:5.3}", epoch, sum_loss / cnt_loss);
       
   218                         let input = Tensor::zeros(&[1, BLOCK_SIZE], (Kind::Int64, device));
       
   219                         println!("Sample: {}", sample(&data, &gpt, input));
       
   220                         if let Err(err) = vs.save(format!("gpt{:08}.ot", idx)) {
       
   221                             println!("error while saving {}", err);
       
   222                         }
       
   223                         sum_loss = 0.;
       
   224                         cnt_loss = 0.;
       
   225                     }
       
   226                 }
       
   227             }
       
   228         }
       
   229         "predict" => {
       
   230             let amqp_url = std::env::var("AMQP_URL").expect("expected AMQP_URL env variabe");
       
   231             let conn = Connection::connect(&amqp_url, ConnectionProperties::default().with_tokio())
       
   232                 .await?;
       
   233 
       
   234             let pub_channel = conn.create_channel().await?;
       
   235             let sub_channel = conn.create_channel().await?;
       
   236 
       
   237             let queue = sub_channel
       
   238                 .queue_declare(
       
   239                     &"",
       
   240                     QueueDeclareOptions {
       
   241                         exclusive: true,
       
   242                         auto_delete: true,
       
   243                         ..QueueDeclareOptions::default()
       
   244                     },
       
   245                     FieldTable::default(),
       
   246                 )
       
   247                 .await?;
       
   248 
       
   249             sub_channel
       
   250                 .queue_bind(
       
   251                     queue.name().as_str(),
       
   252                     "irc",
       
   253                     "cmd.say.hedgewars",
       
   254                     QueueBindOptions::default(),
       
   255                     FieldTable::default(),
       
   256                 )
       
   257                 .await?;
       
   258 
       
   259             let mut subscriber = sub_channel
       
   260                 .basic_consume(
       
   261                     queue.name().as_str(),
       
   262                     &"",
       
   263                     BasicConsumeOptions::default(),
       
   264                     FieldTable::default(),
       
   265                 )
       
   266                 .await?;
       
   267 
       
   268             vs.load(args[2].as_str())?;
       
   269 
       
   270             while let Some(amqp_message) = subscriber.next().await {
       
   271                 let (_, delivery) = amqp_message.expect("error in consumer");
       
   272                 delivery.ack(BasicAckOptions::default()).await?;
       
   273 
       
   274                 let chat_message = String::from_utf8(delivery.data)?;
       
   275                 if let Some((_who, seed)) = chat_message.split_once('\n') {
       
   276                     let input_sample = &format!("\n{}", seed);
       
   277                     let input = Tensor::zeros(&[1, BLOCK_SIZE], (Kind::Int64, device));
       
   278                     for (idx, c) in input_sample.chars().rev().enumerate() {
       
   279                         let idx = idx as i64;
       
   280                         if idx >= BLOCK_SIZE {
       
   281                             break;
       
   282                         }
       
   283                         let _filled = input
       
   284                             .i((0, BLOCK_SIZE - 1 - idx))
       
   285                             .fill_(data.char_to_label(c).unwrap_or(0) as i64);
       
   286                     }
       
   287 
       
   288                     let proceeded_message = &sample(&data, &gpt, input);
       
   289                     let final_message = proceeded_message
       
   290                         .split_once('\n')
       
   291                         .map(|(m, _)| m)
       
   292                         .unwrap_or(proceeded_message);
       
   293                     let final_message = &format!("{}{}", seed, final_message);
       
   294 
       
   295                     println!("{} --> {}", seed, proceeded_message);
       
   296 
       
   297                     pub_channel
       
   298                         .basic_publish(
       
   299                             "irc",
       
   300                             "say.hedgewars",
       
   301                             BasicPublishOptions::default(),
       
   302                             final_message.as_bytes().to_vec(),
       
   303                             BasicProperties::default(),
       
   304                         )
       
   305                         .await?;
       
   306                 }
       
   307             }
       
   308         }
       
   309         _ => bail!("usage: main (train|predict weights.ot)"),
       
   310     };
       
   311 
       
   312     Ok(())
       
   313 }