diff --git a/src/config.rs b/src/config.rs index 0edfb4e..ff2a98e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,7 +1,12 @@ +use crate::error::Error; +use crate::imgur; +use crate::imgur::Image; use config::{Config, File}; +use rand::prelude::SliceRandom; use serde::{Deserialize, Serialize}; use serenity::model::prelude::UserId; use serenity::prelude::TypeMapKey; +use std::collections::HashMap; use std::path::{Path, PathBuf}; use structopt::StructOpt; use tera::Tera; @@ -50,6 +55,64 @@ impl BotConfig { pub struct BotState { pub accepted_nsfw: Option, pub fortune_templates: Tera, + pub albums: HashMap>, +} + +impl BotState { + pub async fn new(cfg: &BotConfig) -> Result { + let mut fortune_templates = Tera::default(); + let mut albums: HashMap> = HashMap::new(); + + for (idx, fortune) in cfg.fortunes.iter().enumerate() { + fortune_templates + .add_raw_template(&idx.to_string(), fortune) + .unwrap(); + } + + for album in &cfg.albums { + albums.insert( + album.name.clone(), + imgur::get_album_images(&cfg.imgur_client_id, &album.album_id).await?, + ); + } + + Ok(Self { + accepted_nsfw: None, + fortune_templates, + albums, + }) + } + + pub fn get_image(&self, album_name: &str, tags: Vec<&str>) -> Option { + let mut rng = rand::thread_rng(); + + let album = match self.albums.get(album_name) { + None => return None, + Some(a) => a, + }; + + let album = if tags.is_empty() { + album.clone() + } else { + album + .iter() + .filter(|img| { + for tag in &tags { + if let Some(desc) = &img.description { + if desc.to_lowercase().contains(&tag.to_lowercase()) { + return true; + } + } + } + + false + }) + .cloned() + .collect() + }; + + album.choose(&mut rng).cloned() + } } #[derive(Debug)] @@ -60,23 +123,21 @@ pub struct GlobalData { } impl GlobalData { - pub fn new(args: Args, cfg: BotConfig) -> Self { - let mut fortune_templates = Tera::default(); - - for (idx, fortune) in cfg.fortunes.iter().enumerate() { - fortune_templates - .add_raw_template(&idx.to_string(), fortune) - .unwrap(); - } - - Self { + pub async fn new(args: Args, cfg: BotConfig) -> Result { + Ok(Self { args, + bot_state: BotState::new(&cfg).await?, cfg, - bot_state: BotState { - accepted_nsfw: None, - fortune_templates, - }, - } + }) + } + + pub async fn reload(&mut self) -> Result<(), Error> { + let cfg = BotConfig::new(&self.args.cfg_path)?; + + self.cfg = cfg; + self.bot_state = BotState::new(&self.cfg).await?; + + Ok(()) } } diff --git a/src/discord/admin.rs b/src/discord/admin.rs new file mode 100644 index 0000000..1764c27 --- /dev/null +++ b/src/discord/admin.rs @@ -0,0 +1,22 @@ +use crate::{command, group, GlobalData}; +use serenity::client::Context; +use serenity::framework::standard::{Args, CommandResult}; +use serenity::model::channel::Message; + +#[group] +#[commands(reload)] +pub struct ADMIN; + +#[command] +#[owners_only] +#[only_in(guilds)] +async fn reload(ctx: &Context, msg: &Message, _args: Args) -> CommandResult { + let mut data = ctx.data.write().await; + let global_data = data.get_mut::().unwrap(); + + global_data.reload().await?; + + msg.reply(&ctx.http, "Reload done ;)").await?; + + Ok(()) +} diff --git a/src/discord/album.rs b/src/discord/album.rs index e92d5b1..e64a070 100644 --- a/src/discord/album.rs +++ b/src/discord/album.rs @@ -1,6 +1,6 @@ use crate::config::AlbumConfig; use crate::error::Error; -use crate::{command, group, imgur, GlobalData}; +use crate::{command, group, GlobalData}; use serenity::client::Context; use serenity::framework::standard::{Args, CommandResult}; use serenity::model::channel::Message; @@ -38,19 +38,29 @@ async fn add_album(ctx: &Context, msg: &Message, mut args: Args) -> CommandResul let global_data = data.get_mut::().unwrap(); + let old_config = global_data.cfg.clone(); + global_data.cfg.albums.push(AlbumConfig { album_id, name: album_name.clone(), }); - global_data - .cfg - .save(&global_data.args.cfg_path) - .await - .unwrap(); + global_data.cfg.save(&global_data.args.cfg_path).await?; - msg.reply(&ctx.http, format!("{} album added!", album_name)) + if global_data.reload().await.is_err() { + global_data.cfg = old_config; + + global_data.cfg.save(&global_data.args.cfg_path).await?; + + msg.reply( + &ctx.http, + "Error adding album, check your link and try again", + ) .await?; + } else { + msg.reply(&ctx.http, format!("{} album added!", album_name)) + .await?; + } Ok(()) } @@ -72,11 +82,9 @@ async fn remove_album(ctx: &Context, msg: &Message, args: Args) -> CommandResult .albums .retain(|album| !album.name.eq_ignore_ascii_case(&album_name)); - global_data - .cfg - .save(&global_data.args.cfg_path) - .await - .unwrap(); + global_data.cfg.save(&global_data.args.cfg_path).await?; + + global_data.reload().await?; msg.reply(&ctx.http, format!("{} album removed!", album_name)) .await?; @@ -123,25 +131,12 @@ pub async fn parse_album( let data = ctx.data.read().await; let global_data = data.get::().unwrap(); - let album = global_data - .cfg - .albums - .iter() - .find(|album| album.name.to_lowercase() == album_name); - - if let Some(album) = album { - match imgur::get_image(album, global_data, tags).await { - Ok(image) => { - if let Some(image) = image { - msg.reply(&ctx.http, &image.link).await?; - } else { - msg.reply(&ctx.http, "No image found ;(").await?; - } - } - Err(_) => { - msg.reply(&ctx.http, "Unable to get album, try again later.") - .await?; - } + match global_data.bot_state.get_image(album_name, tags) { + Some(image) => { + msg.reply(&ctx.http, &image.link).await?; + } + None => { + msg.reply(&ctx.http, "No image ;(").await?; } }; diff --git a/src/discord/joke.rs b/src/discord/joke.rs index b0c5d12..ea98053 100644 --- a/src/discord/joke.rs +++ b/src/discord/joke.rs @@ -1,5 +1,4 @@ use crate::error::Error; -use crate::imgur::get_image; use crate::{command, group, GlobalData}; use rand::prelude::IteratorRandom; use rand::thread_rng; @@ -55,7 +54,7 @@ impl FortuneCtx { let mut random_image: HashMap = HashMap::new(); for album in &global_data.cfg.albums { - let image = get_image(album, global_data, Vec::new()).await?; + let image = global_data.bot_state.get_image(&album.name, Vec::new()); if let Some(image) = image { random_image.insert(album.name.clone(), image.link); diff --git a/src/discord/mod.rs b/src/discord/mod.rs index 1c74a16..063c237 100644 --- a/src/discord/mod.rs +++ b/src/discord/mod.rs @@ -1,3 +1,4 @@ +pub mod admin; pub mod album; pub mod celeryman; pub mod color; @@ -11,9 +12,10 @@ use serenity::framework::standard::{ }; use serenity::model::channel::Message; use serenity::model::id::UserId; -use serenity::model::prelude::Ready; +use serenity::model::prelude::{GuildId, Ready}; use serenity::prelude::EventHandler; use std::collections::HashSet; +use std::time::Duration; pub struct Handler; @@ -23,6 +25,21 @@ static ERROR_MSG: &str = #[async_trait] impl EventHandler for Handler { + async fn cache_ready(&self, ctx: Context, _guilds: Vec) { + tokio::spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(60 * 60)).await; + { + println!("Reloading config..."); + let mut data = ctx.data.write().await; + let global_data = data.get_mut::().unwrap(); + + global_data.reload().await.unwrap(); + } + } + }); + } + async fn message(&self, ctx: Context, new_message: Message) { if new_message.author.bot { return; diff --git a/src/imgur/mod.rs b/src/imgur/mod.rs index 7ce8b54..ce0dfd2 100644 --- a/src/imgur/mod.rs +++ b/src/imgur/mod.rs @@ -1,7 +1,3 @@ -use crate::config::AlbumConfig; -use crate::error::Error; -use crate::GlobalData; -use rand::prelude::SliceRandom; use reqwest::Client; use serde::{Deserialize, Serialize}; use std::fmt::{Display, Formatter}; @@ -75,34 +71,3 @@ pub async fn get_album_images(client_id: &str, album_hash: &str) -> Result, -) -> Result, Error> { - let album = get_album_images(&global_data.cfg.imgur_client_id, &album_config.album_id).await?; - let mut rng = rand::thread_rng(); - - let album = if tags.is_empty() { - album - } else { - album - .iter() - .filter(|img| { - for tag in &tags { - if let Some(desc) = &img.description { - if desc.to_lowercase().contains(&tag.to_lowercase()) { - return true; - } - } - } - - false - }) - .cloned() - .collect() - }; - - Ok(album.choose(&mut rng).cloned()) -} diff --git a/src/main.rs b/src/main.rs index b5c3a4b..8f19037 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,7 +22,13 @@ async fn main() { } }; - let global_data = GlobalData::new(args, cfg); + let global_data = match GlobalData::new(args, cfg).await { + Ok(global_data) => global_data, + Err(err) => { + println!("Error parsing config: {}", err); + return; + } + }; let framework = StandardFramework::new() .configure(|c| c.with_whitespace(true).prefix("!").ignore_bots(true)) @@ -30,6 +36,7 @@ async fn main() { .group(&discord::album::ALBUM_GROUP) .group(&discord::celeryman::CELERYMAN_GROUP) .group(&discord::joke::JOKE_GROUP) + .group(&discord::admin::ADMIN_GROUP) .unrecognised_command(unrecognised_command_hook) .help(&discord::MY_HELP) .after(discord::after);