/* This example uses the tinyshakespeare dataset which can be downloaded at:
   https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

   This is mostly a rust port of https://github.com/karpathy/minGPT
*/

extern crate tch;
use anyhow::{bail, Result as AHResult};
use std::{io, io::Write};
use tch::data::TextData;
use tch::nn::{ModuleT, OptimizerConfig};
use tch::{nn, Device, IndexOp, Kind, Tensor};

use futures::prelude::*;
use lapin::{options::*, types::FieldTable, BasicProperties, Connection, ConnectionProperties};

use tokio_amqp::*;

const LEARNING_RATE: f64 = 0.0003;
const BLOCK_SIZE: i64 = 128;
const BATCH_SIZE: i64 = 64;
const EPOCHS: i64 = 100;
const SAMPLING_LEN: i64 = 512;

#[derive(Debug, Copy, Clone)]
struct Config {
    vocab_size: i64,
    n_embd: i64,
    n_head: i64,
    n_layer: i64,
    block_size: i64,
    attn_pdrop: f64,
    resid_pdrop: f64,
    embd_pdrop: f64,
}

// Weight decay only applies to the weight matrices in the linear layers.
const NO_WEIGHT_DECAY_GROUP: usize = 0;
const WEIGHT_DECAY_GROUP: usize = 1;

// Custom linear layer so that different groups can be used for weights
// and biases.
#[derive(Debug)]
struct Linear {
    pub ws: Tensor,
    pub bs: Tensor,
}

impl nn::Module for Linear {
    fn forward(&self, xs: &Tensor) -> Tensor {
        xs.matmul(&self.ws.tr()) + &self.bs
    }
}

fn linear(vs: nn::Path, in_dim: i64, out_dim: i64) -> Linear {
    let wd = vs.set_group(WEIGHT_DECAY_GROUP);
    let no_wd = vs.set_group(NO_WEIGHT_DECAY_GROUP);
    Linear {
        ws: wd.randn("weight", &[out_dim, in_dim], 0.0, 0.02),
        bs: no_wd.zeros("bias", &[out_dim]),
    }
}

fn linear_no_bias(vs: nn::Path, in_dim: i64, out_dim: i64) -> Linear {
    let wd = vs.set_group(WEIGHT_DECAY_GROUP);
    let no_wd = vs.set_group(NO_WEIGHT_DECAY_GROUP);
    Linear {
        ws: wd.randn("weight", &[out_dim, in_dim], 0.0, 0.02),
        bs: no_wd.zeros_no_train("bias", &[out_dim]),
    }
}

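/// Multi-head causal self-attention: keys, queries and values are linear
/// projections of the input, and a lower-triangular mask prevents each
/// position from attending to later positions.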
fn causal_self_attention(p: &nn::Path, cfg: Config) -> impl ModuleT {
    let key = linear(p / "key", cfg.n_embd, cfg.n_embd);
    let query = linear(p / "query", cfg.n_embd, cfg.n_embd);
    let value = linear(p / "value", cfg.n_embd, cfg.n_embd);
    let proj = linear(p / "proj", cfg.n_embd, cfg.n_embd);
    let mask_init =
        Tensor::ones(&[cfg.block_size, cfg.block_size], (Kind::Float, p.device())).tril(0);
    let mask_init = mask_init.view([1, 1, cfg.block_size, cfg.block_size]);
    // let mask = p.var_copy("mask", &mask_init);
    let mask = mask_init;
    nn::func_t(move |xs, train| {
        let (sz_b, sz_t, sz_c) = xs.size3().unwrap();
        let sizes = [sz_b, sz_t, cfg.n_head, sz_c / cfg.n_head];
        let k = xs.apply(&key).view(sizes).transpose(1, 2);
        let q = xs.apply(&query).view(sizes).transpose(1, 2);
        let v = xs.apply(&value).view(sizes).transpose(1, 2);
        let att = q.matmul(&k.transpose(-2, -1)) * (1.0 / f64::sqrt(sizes[3] as f64));
        let att = att.masked_fill(
            &mask.i((.., .., ..sz_t, ..sz_t)).eq(0.),
            std::f64::NEG_INFINITY,
        );
        let att = att.softmax(-1, Kind::Float).dropout(cfg.attn_pdrop, train);
        let ys = att
            .matmul(&v)
            .transpose(1, 2)
            .contiguous()
            .view([sz_b, sz_t, sz_c]);
        ys.apply(&proj).dropout(cfg.resid_pdrop, train)
    })
}

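/// A single transformer block: layer-norm followed by self-attention, then
/// layer-norm followed by a two-layer GELU MLP, each with a residual connection.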
fn block(p: &nn::Path, cfg: Config) -> impl ModuleT {
    let ln1 = nn::layer_norm(p / "ln1", vec![cfg.n_embd], Default::default());
    let ln2 = nn::layer_norm(p / "ln2", vec![cfg.n_embd], Default::default());
    let attn = causal_self_attention(p, cfg);
    let lin1 = linear(p / "lin1", cfg.n_embd, 4 * cfg.n_embd);
    let lin2 = linear(p / "lin2", 4 * cfg.n_embd, cfg.n_embd);
    nn::func_t(move |xs, train| {
        let xs = xs + xs.apply(&ln1).apply_t(&attn, train);
        let ys = xs
            .apply(&ln2)
            .apply(&lin1)
            .gelu()
            .apply(&lin2)
            .dropout(cfg.resid_pdrop, train);
        xs + ys
    })
}

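/// The full GPT model: token and position embeddings, `n_layer` transformer
/// blocks, a final layer norm and a linear head projecting back to the
/// vocabulary logits.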
fn gpt(p: &nn::Path, cfg: Config) -> impl ModuleT {
    let p = &p.set_group(NO_WEIGHT_DECAY_GROUP);
    let tok_emb = nn::embedding(
        p / "tok_emb",
        cfg.vocab_size,
        cfg.n_embd,
        Default::default(),
    );
    let pos_emb = p.zeros("pos_emb", &[1, cfg.block_size, cfg.n_embd]);
    let ln_f = nn::layer_norm(p / "ln_f", vec![cfg.n_embd], Default::default());
    let head = linear_no_bias(p / "head", cfg.n_embd, cfg.vocab_size);
    let mut blocks = nn::seq_t();
    for block_idx in 0..cfg.n_layer {
        blocks = blocks.add(block(&(p / block_idx), cfg));
    }
    nn::func_t(move |xs, train| {
        let (_sz_b, sz_t) = xs.size2().unwrap();
        let tok_emb = xs.apply(&tok_emb);
        let pos_emb = pos_emb.i((.., ..sz_t, ..));
        (tok_emb + pos_emb)
            .dropout(cfg.embd_pdrop, train)
            .apply_t(&blocks, train)
            .apply(&ln_f)
            .apply(&head)
    })
}

/// Generates a sample string using the GPT model.
fn sample(data: &TextData, gpt: &impl ModuleT, input: Tensor) -> String {
    let mut input = input;
    let mut result = String::new();
    for _index in 0..SAMPLING_LEN {
        let logits = input.apply_t(gpt, false).i((0, -1, ..));
        let sampled_y = logits.softmax(-1, Kind::Float).multinomial(1, true);
        let last_label = i64::from(&sampled_y);
        result.push(data.label_to_char(last_label));
        // Append the sampled token and keep the last BLOCK_SIZE tokens as the new context.
        input = Tensor::cat(&[input, sampled_y.view([1, 1])], 1).narrow(1, 1, BLOCK_SIZE);
    }
    result
}

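// Two modes: `train` trains the model on the text file and periodically saves
// checkpoints, `predict <weights.ot>` loads a checkpoint and runs as an AMQP
// consumer that completes incoming chat messages.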
#[tokio::main]
async fn main() -> AHResult<()> {
    let device = Device::cuda_if_available();
    let mut vs = nn::VarStore::new(device);
    let data = TextData::new("10.log")?;
    let labels = data.labels();
    println!("Dataset loaded, {} labels.", labels);
    let cfg = Config {
        vocab_size: labels,
        n_embd: 384, // was 512
        n_head: 8,
        n_layer: 8,
        block_size: BLOCK_SIZE,
        attn_pdrop: 0.1,
        resid_pdrop: 0.1,
        embd_pdrop: 0.1,
    };
    let gpt = gpt(&(&vs.root() / "gpt"), cfg);
    let args: Vec<_> = std::env::args().collect();
    if args.len() < 2 {
        bail!("usage: main (train|predict weights.ot)")
    }
    match args[1].as_str() {
        "train" => {
            let mut opt = nn::AdamW::default().build(&vs, LEARNING_RATE)?;
            opt.set_weight_decay_group(NO_WEIGHT_DECAY_GROUP, 0.0);
            opt.set_weight_decay_group(WEIGHT_DECAY_GROUP, 0.1);
            let mut idx = 0;
            // Resume training from a previously saved checkpoint.
            vs.load("384.ot")?;
            for epoch in 1..(1 + EPOCHS) {
                let mut sum_loss = 0.;
                let mut cnt_loss = 0.;
                for batch in data.iter_shuffle(BLOCK_SIZE + 1, BATCH_SIZE) {
                    let xs = batch
                        .narrow(1, 0, BLOCK_SIZE)
                        .to_kind(Kind::Int64)
                        .to_device(device);
                    let ys = batch
                        .narrow(1, 1, BLOCK_SIZE)
                        .to_kind(Kind::Int64)
                        .to_device(device);
                    let logits = xs.apply_t(&gpt, true);
                    let loss = logits
                        .view([BATCH_SIZE * BLOCK_SIZE, labels])
                        .cross_entropy_for_logits(&ys.view([BATCH_SIZE * BLOCK_SIZE]));
                    opt.backward_step_clip(&loss, 0.5);
                    sum_loss += f64::from(loss);
                    cnt_loss += 1.0;
                    idx += 1;
                    if idx % 10 == 0 {
                        print!("{}", '.');
                        io::stdout().flush()?;
                    }
                    if idx % 1000 == 0 {
                        println!("Epoch: {} loss: {:5.3}", epoch, sum_loss / cnt_loss);
                        let input = Tensor::zeros(&[1, BLOCK_SIZE], (Kind::Int64, device));
                        println!("Sample: {}", sample(&data, &gpt, input));
                        if let Err(err) = vs.save(format!("gpt{:08}.ot", idx)) {
                            println!("error while saving {}", err);
                        }
                        sum_loss = 0.;
                        cnt_loss = 0.;
                    }
                }
            }
        }
229 "predict" => { |
|
230 let amqp_url = std::env::var("AMQP_URL").expect("expected AMQP_URL env variabe"); |
|
231 let conn = Connection::connect(&amqp_url, ConnectionProperties::default().with_tokio()) |
|
232 .await?; |
|
233 |
|
234 let pub_channel = conn.create_channel().await?; |
|
235 let sub_channel = conn.create_channel().await?; |
|
236 |
|
237 let queue = sub_channel |
|
238 .queue_declare( |
|
239 &"", |
|
240 QueueDeclareOptions { |
|
241 exclusive: true, |
|
242 auto_delete: true, |
|
243 ..QueueDeclareOptions::default() |
|
244 }, |
|
245 FieldTable::default(), |
|
246 ) |
|
247 .await?; |
|
248 |
|
249 sub_channel |
|
250 .queue_bind( |
|
251 queue.name().as_str(), |
|
252 "irc", |
|
253 "cmd.say.hedgewars", |
|
254 QueueBindOptions::default(), |
|
255 FieldTable::default(), |
|
256 ) |
|
257 .await?; |
|
258 |
|
259 let mut subscriber = sub_channel |
|
260 .basic_consume( |
|
261 queue.name().as_str(), |
|
262 &"", |
|
263 BasicConsumeOptions::default(), |
|
264 FieldTable::default(), |
|
265 ) |
|
266 .await?; |
|
267 |
|
268 vs.load(args[2].as_str())?; |
|
269 |
|
270 while let Some(amqp_message) = subscriber.next().await { |
|
271 let (_, delivery) = amqp_message.expect("error in consumer"); |
|
272 delivery.ack(BasicAckOptions::default()).await?; |
|
273 |
|
274 let chat_message = String::from_utf8(delivery.data)?; |
|
275 if let Some((_who, seed)) = chat_message.split_once('\n') { |
|
276 let input_sample = &format!("\n{}", seed); |
|
277 let input = Tensor::zeros(&[1, BLOCK_SIZE], (Kind::Int64, device)); |
|
278 for (idx, c) in input_sample.chars().rev().enumerate() { |
|
279 let idx = idx as i64; |
|
280 if idx >= BLOCK_SIZE { |
|
281 break; |
|
282 } |
|
283 let _filled = input |
|
284 .i((0, BLOCK_SIZE - 1 - idx)) |
|
285 .fill_(data.char_to_label(c).unwrap_or(0) as i64); |
|
286 } |
|
287 |
|
288 let proceeded_message = &sample(&data, &gpt, input); |
|
289 let final_message = proceeded_message |
|
290 .split_once('\n') |
|
291 .map(|(m, _)| m) |
|
292 .unwrap_or(proceeded_message); |
|
293 let final_message = &format!("{}{}", seed, final_message); |
|
294 |
|
295 println!("{} --> {}", seed, proceeded_message); |
|
296 |
|
297 pub_channel |
|
298 .basic_publish( |
|
299 "irc", |
|
300 "say.hedgewars", |
|
301 BasicPublishOptions::default(), |
|
302 final_message.as_bytes().to_vec(), |
|
303 BasicProperties::default(), |
|
304 ) |
|
305 .await?; |
|
306 } |
|
307 } |
|
308 } |
|
309 _ => bail!("usage: main (train|predict weights.ot)"), |
|
310 }; |
|
311 |
|
312 Ok(()) |
|
313 } |