tools/ubot-plugins/url-bot-rs/src/bin/ubot-url-plugin.rs
changeset 15790 efe4e3290870
equal deleted inserted replaced
15789:d97ea528ce95 15790:efe4e3290870
       
     1 use url_bot_rs::config::Rtd;
       
     2 use url_bot_rs::VERSION;
       
     3 use url_bot_rs::{feat, http::resolve_url, param, plugins::TITLE_PLUGINS, tld::TLD};
       
     4 
       
     5 use anyhow::Result as AHResult;
       
     6 use atty::{is, Stream};
       
     7 use directories::ProjectDirs;
       
     8 use docopt::Docopt;
       
     9 use failure::Error;
       
    10 use lazy_static::lazy_static;
       
    11 use log::{error, info};
       
    12 use regex::Regex;
       
    13 use reqwest::Url;
       
    14 use serde_derive::Deserialize;
       
    15 use std::collections::HashSet;
       
    16 use std::path::PathBuf;
       
    17 use stderrlog::{ColorChoice, Timestamp};
       
    18 
       
    19 use lapin::{options::*, types::FieldTable, BasicProperties, Connection, ConnectionProperties};
       
    20 use tokio_amqp::*;
       
    21 
       
    22 use futures::prelude::*;
       
    23 
       
    24 use rand::distributions::Alphanumeric;
       
    25 use rand::{thread_rng, Rng};
       
    26 
       
    27 use std::sync::mpsc;
       
    28 use std::thread;
       
    29 
       
    30 // docopt usage string
       
    31 const USAGE: &str = "
       
    32 URL munching IRC bot.
       
    33 
       
    34 Usage:
       
    35     ubot-url-plugin [options] [-v...] [--conf=PATH...] [--conf-dir=DIR...]
       
    36 
       
    37 Options:
       
    38     -h --help           Show this help message.
       
    39     --version           Print version.
       
    40     -v --verbose        Show extra information.
       
    41     -t --timestamp      Force timestamps.
       
    42 ";
       
    43 
       
    44 #[derive(Debug, Deserialize, Default)]
       
    45 pub struct Args {
       
    46     flag_verbose: usize,
       
    47     flag_conf: Vec<PathBuf>,
       
    48     flag_conf_dir: Vec<PathBuf>,
       
    49     flag_timestamp: bool,
       
    50 }
       
    51 
       
    52 const MIN_VERBOSITY: usize = 2;
       
    53 
       
    54 #[derive(Debug, PartialEq)]
       
    55 enum TitleResp {
       
    56     Title(String),
       
    57     Error(String),
       
    58 }
       
    59 
       
    60 /// Run available plugins on a single URL, return the first successful title.
       
    61 fn process_plugins(rtd: &Rtd, url: &Url) -> Option<String> {
       
    62     let result: String = TITLE_PLUGINS
       
    63         .iter()
       
    64         .filter(|p| p.check(&rtd.conf.plugins, url))
       
    65         .filter_map(|p| p.evaluate(&rtd, url).ok())
       
    66         .take(1)
       
    67         .collect();
       
    68 
       
    69     if result.is_empty() {
       
    70         None
       
    71     } else {
       
    72         Some(result)
       
    73     }
       
    74 }
       
    75 
       
    76 /// find titles in a message and generate responses
       
    77 fn process_titles(rtd: &Rtd, msg: &str) -> impl Iterator<Item = TitleResp> {
       
    78     let mut responses: Vec<TitleResp> = vec![];
       
    79 
       
    80     let mut num_processed = 0;
       
    81     let mut dedup_urls = HashSet::new();
       
    82 
       
    83     // look at each space-separated message token
       
    84     for token in msg.split_whitespace() {
       
    85         // the token must not contain unsafe characters
       
    86         if contains_unsafe_chars(token) {
       
    87             continue;
       
    88         }
       
    89 
       
    90         // get a full URL for tokens without a scheme
       
    91         let maybe_token = if feat!(rtd, partial_urls) {
       
    92             add_scheme_for_tld(token)
       
    93         } else {
       
    94             None
       
    95         };
       
    96 
       
    97         let token = maybe_token.as_ref().map_or(token, String::as_str);
       
    98 
       
    99         // the token must be a valid url
       
   100         let url = match token.parse::<Url>() {
       
   101             Ok(url) => url,
       
   102             _ => continue,
       
   103         };
       
   104 
       
   105         // the scheme must be http or https
       
   106         if !["http", "https"].contains(&url.scheme()) {
       
   107             continue;
       
   108         }
       
   109 
       
   110         // skip duplicate urls within the message
       
   111         if dedup_urls.contains(&url) {
       
   112             continue;
       
   113         }
       
   114 
       
   115         info!("[{}] RESOLVE <{}>", rtd.conf.network.name, token);
       
   116 
       
   117         // try to get the title from the url
       
   118         let title = if let Some(title) = process_plugins(rtd, &url) {
       
   119             title
       
   120         } else {
       
   121             match resolve_url(token, rtd) {
       
   122                 Ok(title) => title,
       
   123                 Err(err) => {
       
   124                     error!("{:?}", err);
       
   125                     responses.push(TitleResp::Error(err.to_string()));
       
   126                     continue;
       
   127                 }
       
   128             }
       
   129         };
       
   130 
       
   131         // limit response length, see RFC1459
       
   132 
       
   133         let msg = utf8_truncate(&format!("⤷ {}", title), 510);
       
   134 
       
   135         info!("[{}] {}", rtd.conf.network.name, msg);
       
   136 
       
   137         responses.push(TitleResp::Title(msg.to_string()));
       
   138 
       
   139         dedup_urls.insert(url);
       
   140 
       
   141         // limit the number of processed URLs
       
   142         num_processed += 1;
       
   143         if num_processed == param!(rtd, url_limit) {
       
   144             break;
       
   145         }
       
   146     }
       
   147 
       
   148     responses.into_iter()
       
   149 }
       
   150 
       
   151 // regex for unsafe characters, as defined in RFC 1738
       
   152 const RE_UNSAFE_CHARS: &str = r#"[{}|\\^~\[\]`<>"]"#;
       
   153 
       
   154 /// does the token contain characters not permitted by RFC 1738
       
   155 fn contains_unsafe_chars(token: &str) -> bool {
       
   156     lazy_static! {
       
   157         static ref UNSAFE: Regex = Regex::new(RE_UNSAFE_CHARS).unwrap();
       
   158     }
       
   159     UNSAFE.is_match(token)
       
   160 }
       
   161 
       
   162 /// truncate to a maximum number of bytes, taking UTF-8 into account
       
   163 fn utf8_truncate(s: &str, n: usize) -> String {
       
   164     s.char_indices()
       
   165         .take_while(|(len, c)| len + c.len_utf8() <= n)
       
   166         .map(|(_, c)| c)
       
   167         .collect()
       
   168 }
       
   169 
       
   170 lazy_static! {
       
   171     static ref REPEATED_DOTS: Regex = Regex::new(r"\.\.+").unwrap();
       
   172 }
       
   173 
       
   174 /// if a token has a recognised TLD, but no scheme, add one
       
   175 pub fn add_scheme_for_tld(token: &str) -> Option<String> {
       
   176     if token.parse::<Url>().is_err() {
       
   177         if token.starts_with(|s: char| !s.is_alphabetic()) {
       
   178             return None;
       
   179         }
       
   180 
       
   181         if REPEATED_DOTS.is_match(&token) {
       
   182             return None;
       
   183         }
       
   184 
       
   185         let new_token = format!("http://{}", token);
       
   186 
       
   187         if let Ok(url) = new_token.parse::<Url>() {
       
   188             if !url.domain()?.contains('.') {
       
   189                 return None;
       
   190             }
       
   191 
       
   192             // reject email addresses
       
   193             if url.username() != "" {
       
   194                 return None;
       
   195             }
       
   196 
       
   197             let tld = url.domain()?.split('.').last()?;
       
   198 
       
   199             if TLD.contains(tld) {
       
   200                 return Some(new_token);
       
   201             }
       
   202         }
       
   203     }
       
   204 
       
   205     None
       
   206 }
       
   207 
       
   208 fn init_rtd() -> AHResult<Rtd, Error> {
       
   209     // parse command line arguments with docopt
       
   210     let args: Args = Docopt::new(USAGE)
       
   211         .and_then(|d| d.version(Some(VERSION.to_string())).deserialize())
       
   212         .unwrap_or_else(|e| e.exit());
       
   213 
       
   214     // avoid timestamping when piped, e.g. systemd
       
   215     let timestamp = if is(Stream::Stderr) || args.flag_timestamp {
       
   216         Timestamp::Second
       
   217     } else {
       
   218         Timestamp::Off
       
   219     };
       
   220 
       
   221     stderrlog::new()
       
   222         .module(module_path!())
       
   223         .modules(vec![
       
   224             "url_bot_rs::message",
       
   225             "url_bot_rs::config",
       
   226             "url_bot_rs::http",
       
   227         ])
       
   228         .verbosity(args.flag_verbose + MIN_VERBOSITY)
       
   229         .timestamp(timestamp)
       
   230         .color(ColorChoice::Never)
       
   231         .init()
       
   232         .unwrap();
       
   233 
       
   234     let dirs = ProjectDirs::from("org", "", "url-bot-rs").unwrap();
       
   235     let default_conf_dir = dirs.config_dir();
       
   236 
       
   237     let default_conf = default_conf_dir.join("config.toml");
       
   238 
       
   239     let rtd: Rtd = Rtd::new().conf(&default_conf).load()?.init_http_client()?;
       
   240 
       
   241     Ok(rtd)
       
   242 }
       
   243 
       
   244 fn random_string(size: usize) -> String {
       
   245     thread_rng()
       
   246         .sample_iter(&Alphanumeric)
       
   247         .take(size)
       
   248         .map(char::from)
       
   249         .collect()
       
   250 }
       
   251 
       
   252 #[tokio::main]
       
   253 async fn main() -> AHResult<()> {
       
   254     let (tx1, rx1) = mpsc::channel::<String>();
       
   255     let (tx2, rx2) = mpsc::channel();
       
   256 
       
   257     thread::spawn(move || {
       
   258         let rtd = init_rtd().expect("RTD not initialized");
       
   259 
       
   260         loop {
       
   261             let message = &rx1.recv().expect("rx1 recv error");
       
   262             let titles: Vec<_> = process_titles(&rtd, message).collect();
       
   263             tx2.send(titles).expect("tx2 send error");
       
   264         }
       
   265     });
       
   266     let amqp_url = std::env::var("AMQP_URL").expect("expected AMQP_URL env variabe");
       
   267     let conn = Connection::connect(&amqp_url, ConnectionProperties::default().with_tokio()).await?;
       
   268 
       
   269     let pub_channel = conn.create_channel().await?;
       
   270     let sub_channel = conn.create_channel().await?;
       
   271 
       
   272     let queue = sub_channel
       
   273         .queue_declare(
       
   274             &random_string(32),
       
   275             QueueDeclareOptions {
       
   276                 exclusive: true,
       
   277                 auto_delete: true,
       
   278                 ..QueueDeclareOptions::default()
       
   279             },
       
   280             FieldTable::default(),
       
   281         )
       
   282         .await?;
       
   283 
       
   284     sub_channel
       
   285         .queue_bind(
       
   286             queue.name().as_str(),
       
   287             "irc",
       
   288             "msg.hedgewars",
       
   289             QueueBindOptions::default(),
       
   290             FieldTable::default(),
       
   291         )
       
   292         .await?;
       
   293 
       
   294     let mut subscriber = sub_channel
       
   295         .basic_consume(
       
   296             queue.name().as_str(),
       
   297             &random_string(32),
       
   298             BasicConsumeOptions::default(),
       
   299             FieldTable::default(),
       
   300         )
       
   301         .await?;
       
   302 
       
   303     while let Some(amqp_message) = subscriber.next().await {
       
   304         let (_, delivery) = amqp_message.expect("error in consumer");
       
   305         delivery.ack(BasicAckOptions::default()).await?;
       
   306 
       
   307         let chat_message = String::from_utf8(delivery.data)?;
       
   308         if let Some((_who, message)) = chat_message.split_once('\n') {
       
   309             tx1.send(message.to_owned())?;
       
   310             let titles = rx2.recv()?;
       
   311 
       
   312             for title in titles {
       
   313                 let title_message = match title {
       
   314                     TitleResp::Title(t) => t,
       
   315                     TitleResp::Error(e) => e,
       
   316                 };
       
   317                 pub_channel
       
   318                     .basic_publish(
       
   319                         "irc",
       
   320                         "say.hedgewars",
       
   321                         BasicPublishOptions::default(),
       
   322                         title_message.as_bytes().to_vec(),
       
   323                         BasicProperties::default(),
       
   324                     )
       
   325                     .await?;
       
   326             }
       
   327         }
       
   328     }
       
   329 
       
   330     Ok(())
       
   331 }