diff --git a/Cargo.lock b/Cargo.lock index 61e3c26..7385dee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -126,6 +126,56 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "axum" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "678c5130a507ae3a7c797f9a17393c14849300b8440eac47cdb90a5bdcb3a543" +dependencies = [ + "async-trait", + "axum-core", + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "hyper", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-http", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cae3e661676ffbacb30f1a824089a8c9150e71017f7e1e38f2aa32009188d34" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "mime", + "rustversion", + "tower-layer", + "tower-service", +] + [[package]] name = "base64" version = "0.13.1" @@ -606,6 +656,8 @@ dependencies = [ name = "fren" version = "0.1.0" dependencies = [ + "axum", + "base64 0.21.0", "config", "j_db", "json", @@ -616,6 +668,7 @@ dependencies = [ "reqwest", "serde", "serenity", + "sha3", "songbird", "structopt", "tera", @@ -874,6 +927,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-range-header" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29" + [[package]] name = "httparse" version = "1.8.0" @@ -1064,6 +1123,15 @@ dependencies = [ "serde", ] +[[package]] +name = "keccak" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3afef3b6eff9ce9d8ff9b3601125eec7f0c8cbac7abd14f355d053fa56c98768" +dependencies = [ + "cpufeatures", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -1167,6 +1235,12 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matchit" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40" + [[package]] name = "memchr" version = "2.5.0" @@ -2001,6 +2075,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341" +dependencies = [ + "serde", +] + [[package]] name = "serde_repr" version = "0.1.10" @@ -2094,6 +2177,16 @@ dependencies = [ "digest", ] +[[package]] +name = "sha3" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdf0c33fae925bdc080598b84bc15c55e7b9a4a43b3c704da051f977469691c9" +dependencies = [ + "digest", + "keccak", +] + [[package]] name = "sharded-slab" version = "0.1.4" @@ -2299,6 +2392,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8" + [[package]] name = "tempfile" version = "3.3.0" @@ -2498,6 +2597,47 @@ dependencies = [ "serde", ] +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f873044bf02dd1e8239e9c1293ea39dad76dc594ec16185d0a1bf31d8dc8d858" +dependencies = [ + "bitflags", + "bytes", + "futures-core", + "futures-util", + "http", + "http-body", + "http-range-header", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" + [[package]] name = "tower-service" version = "0.3.2" diff --git a/Cargo.toml b/Cargo.toml index f951f7b..9be9272 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,9 @@ magick_rust = "0.17.0" songbird = "0.3.0" json = "0.12.4" j_db = {git = "https://git.jojodev.com/joeyahines/j_db"} +axum = "0.6.3" +sha3 = "0.10.6" +base64 = "0.21.0" [dependencies.serenity] version = "0.11.5" diff --git a/src/api/mod.rs b/src/api/mod.rs new file mode 100644 index 0000000..bbc3d9b --- /dev/null +++ b/src/api/mod.rs @@ -0,0 +1,57 @@ +use crate::config::GlobalData; +use crate::discord::voices::speak; +use crate::models::api_key::Apikey; +use axum::extract::State; +use axum::{http::StatusCode, response::IntoResponse, routing::post, Json, Router}; +use serde::Deserialize; +use serenity::prelude::Context; + +pub async fn web_server(ctx: Context) { + let addr = { + let data = ctx.data.read().await; + let global_data = data.get::().unwrap(); + global_data.cfg.api_addr + }; + + let app = Router::new() + .route("/play", post(play_sound)) + .with_state(ctx); + + println!("Serving bot api on: {}", addr); + axum::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); +} + +#[derive(Deserialize)] +struct SoundPayload { + pub api_key: String, + pub voice: String, + pub phrase: String, +} + +async fn play_sound( + State(ctx): State, + Json(payload): Json, +) -> impl IntoResponse { + let data = ctx.data.read().await; + let global_data = data.get::().unwrap(); + + if let Some(api_key) = Apikey::find_key_from_secret(&global_data.db, &payload.api_key).unwrap() + { + if let Some(user_id) = api_key.user_id { + speak( + &ctx, + global_data.cfg.guild_id, + user_id, + &payload.voice, + &payload.phrase, + ) + .await + .unwrap(); + } + } + + StatusCode::ACCEPTED +} diff --git a/src/config.rs b/src/config.rs index ae3a18c..028aba1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -6,9 +6,10 @@ use config::{Config, File}; use j_db::database::Database; use rand::prelude::SliceRandom; use serde::{Deserialize, Serialize}; -use serenity::model::prelude::UserId; +use serenity::model::prelude::{GuildId, UserId}; use serenity::prelude::TypeMapKey; use std::collections::HashMap; +use std::net::SocketAddr; use std::path::{Path, PathBuf}; use std::sync::Arc; use structopt::StructOpt; @@ -30,6 +31,8 @@ pub struct BotConfig { pub nft_path: PathBuf, pub db_path: PathBuf, pub admins: Vec, + pub guild_id: GuildId, + pub api_addr: SocketAddr, } impl BotConfig { @@ -47,6 +50,7 @@ pub struct BotState { pub accepted_nsfw: Option, pub albums: HashMap>, pub bad_apple_running: bool, + pub speak_lock: Mutex<()>, } impl BotState { @@ -64,6 +68,7 @@ impl BotState { accepted_nsfw: None, albums, bad_apple_running: false, + speak_lock: Mutex::new(()), }) } diff --git a/src/discord/admin.rs b/src/discord/admin.rs index 66aebab..009d1d5 100644 --- a/src/discord/admin.rs +++ b/src/discord/admin.rs @@ -1,4 +1,5 @@ use crate::config::BotConfig; +use crate::models::api_key::Apikey; use crate::{command, group, GlobalData}; use json::JsonValue; use serenity::client::Context; @@ -8,7 +9,7 @@ use serenity::model::prelude::UserId; use std::borrow::Cow; #[group] -#[commands(reload, dump_db, load_db)] +#[commands(reload, dump_db, load_db, add_key)] pub struct ADMIN; pub fn is_admin(user_id: &UserId, cfg: &BotConfig) -> bool { @@ -96,3 +97,33 @@ async fn load_db(ctx: &Context, msg: &Message, _args: Args) -> CommandResult { Ok(()) } + +#[command] +async fn add_key(ctx: &Context, msg: &Message, args: Args) -> CommandResult { + let mut data = ctx.data.write().await; + let global_data = data.get_mut::().unwrap(); + + if !is_admin(&msg.author.id, &global_data.cfg) { + return Ok(()); + } + + let (api_key, key) = if args.len() == 1 { + let user_id = args.parse::()?; + let user = user_id.to_user(&ctx.http).await?; + Apikey::new(&format!("{}'s Key", user.name), Some(user_id)) + } else { + Apikey::new(&format!("{}'s Key", msg.author.name), Some(msg.author.id)) + }; + + global_data.db.insert::(api_key.clone())?; + + let dm = msg.author.create_dm_channel(&ctx.http).await?; + + dm.say( + &ctx.http, + format!("Key '{}' added. Api Key: {}", api_key.name, key), + ) + .await?; + + Ok(()) +} diff --git a/src/discord/mod.rs b/src/discord/mod.rs index 23d3f9e..319849c 100644 --- a/src/discord/mod.rs +++ b/src/discord/mod.rs @@ -10,6 +10,7 @@ pub mod shop; pub mod story; pub mod voices; +use crate::api::web_server; use crate::discord::fren_coin::give_coin; use crate::discord::joke::random; use crate::discord::shop::restock_shop; @@ -126,6 +127,8 @@ impl EventHandler for Handler { OnlineStatus::Online, ) .await; + + tokio::spawn(async move { web_server(ctx).await }); } } diff --git a/src/discord/voices.rs b/src/discord/voices.rs index 99d8ef7..36a36e1 100644 --- a/src/discord/voices.rs +++ b/src/discord/voices.rs @@ -2,12 +2,15 @@ use crate::{command, group, GlobalData}; use serenity::client::Context; use serenity::framework::standard::{Args, CommandResult}; use serenity::model::channel::{AttachmentType, Message}; +use serenity::model::id::UserId; +use serenity::model::prelude::GuildId; use serenity::utils::MessageBuilder; use songbird::driver::Bitrate; use songbird::input; use songbird::input::cached::Compressed; use std::borrow::Cow; use std::collections::HashMap; +use std::fmt::{Display, Formatter}; use std::path::{Path, PathBuf}; #[group] @@ -75,33 +78,58 @@ async fn find_voice(voice_path: &Path, name: &str) -> Result, to Ok(None) } -#[command] -#[only_in(guilds)] -#[min_args(1)] -async fn say(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { - let guild = msg.guild(&ctx.cache).unwrap(); - let guild_id = guild.id; +#[derive(Debug)] +pub enum VoiceError { + VoiceNotFound(String), + WordNotFound(String), + NotInVoiceChannel, + Serenity(serenity::Error), +} +impl Display for VoiceError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + VoiceError::VoiceNotFound(v) => write!(f, "{} voice not found", v), + VoiceError::WordNotFound(w) => write!(f, "{} is not in dictionary", w), + VoiceError::Serenity(err) => write!(f, "Serenity error: {}", err), + VoiceError::NotInVoiceChannel => write!(f, "User not in voice channel"), + } + } +} + +impl From for VoiceError { + fn from(value: serenity::Error) -> Self { + Self::Serenity(value) + } +} + +impl std::error::Error for VoiceError {} + +pub async fn speak( + ctx: &Context, + guild_id: GuildId, + user_id: UserId, + voice: &str, + phrase: &str, +) -> Result<(), VoiceError> { let data = ctx.data.read().await; let global_data = data.get::().unwrap(); - let voice = args.parse::().unwrap(); + let _ = global_data.bot_state.speak_lock.lock().await; - let voice_path = match find_voice(&global_data.cfg.voice_path, &voice).await? { - None => { - msg.reply(&ctx.http, format!("No voice found called '{}'", voice)) - .await?; - return Ok(()); - } + let voice_path = match find_voice(&global_data.cfg.voice_path, voice) + .await + .unwrap() + { + None => return Err(VoiceError::VoiceNotFound(voice.to_string())), Some(voice_path) => voice_path, }; - args.advance(); + let dict = get_voice_dictionary(&voice_path).await.unwrap(); - let dict = get_voice_dictionary(&voice_path).await?; - - let mut phrase = Vec::new(); - while let Some(word) = &args.current() { + let mut sentence = Vec::new(); + for word in phrase.split(' ') { + let word = word.to_lowercase(); let mut add_period = false; let mut add_comma = false; let word = if word.ends_with(',') { @@ -115,38 +143,31 @@ async fn say(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { }; if dict.contains_key(&word) { - phrase.push(word.to_string()); + sentence.push(word.to_string()); } else { - msg.reply( - &ctx.http, - format!("The word '{}' is not in the dictionary", word), - ) - .await?; - return Ok(()); + return Err(VoiceError::WordNotFound(word)); } if add_comma { - phrase.push("_comma".to_string()); + sentence.push("_comma".to_string()); } if add_period { - phrase.push("_period".to_string()); + sentence.push("_period".to_string()); } - - args.advance(); } + let guild = guild_id.to_guild_cached(&ctx.cache).unwrap(); + let channel_id = guild .voice_states - .get(&msg.author.id) + .get(&user_id) .and_then(|voice_state| voice_state.channel_id); let connect_to = match channel_id { Some(channel) => channel, None => { - msg.reply(ctx, "You are not in a voice channel").await?; - - return Ok(()); + return Err(VoiceError::NotInVoiceChannel); } }; @@ -157,11 +178,11 @@ async fn say(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { let (handler_lock, success_reader) = manager.join(guild_id, connect_to).await; - success_reader?; + success_reader.unwrap(); let mut handler = handler_lock.lock().await; - for word in phrase { + for word in sentence { let word_path = dict.get(&word).unwrap(); let audio_src = Compressed::new( @@ -175,12 +196,37 @@ async fn say(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { let duration = audio_src.metadata.duration.unwrap(); let voice = handler.play_source(audio_src.into()); - voice.set_volume(0.5)?; + voice.set_volume(0.5).unwrap(); tokio::time::sleep(duration).await; } - handler.leave().await?; + handler.leave().await.unwrap(); + + Ok(()) +} + +#[command] +#[only_in(guilds)] +#[min_args(1)] +async fn say(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { + let guild = msg.guild(&ctx.cache).unwrap(); + let guild_id = guild.id; + + let voice = args.parse::()?; + args.advance(); + let phrase = args.rest(); + + if let Err(err) = speak(ctx, guild_id, msg.author.id, &voice, phrase).await { + match err { + VoiceError::VoiceNotFound(_) + | VoiceError::WordNotFound(_) + | VoiceError::NotInVoiceChannel => { + msg.reply(&ctx.http, format!("Error: {}", err)).await?; + } + _ => return Err(err.into()), + } + } Ok(()) } diff --git a/src/main.rs b/src/main.rs index fbb2081..3a541aa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +mod api; mod config; mod discord; mod error; diff --git a/src/models/api_key.rs b/src/models/api_key.rs new file mode 100644 index 0000000..788f8ef --- /dev/null +++ b/src/models/api_key.rs @@ -0,0 +1,70 @@ +use crate::error::Error; +use base64::{engine::general_purpose, Engine as _}; +use j_db::database::Database; +use j_db::model::JdbModel; +use rand::distributions::Alphanumeric; +use rand::{thread_rng, Rng}; +use serde::{Deserialize, Serialize}; +use serenity::model::id::UserId; +use sha3::digest::FixedOutput; +use sha3::{Digest, Sha3_256}; + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct Apikey { + id: Option, + pub name: String, + pub hash: String, + pub user_id: Option, +} + +impl JdbModel for Apikey { + fn id(&self) -> Option { + self.id + } + + fn set_id(&mut self, id: u64) { + self.id = Some(id) + } + + fn tree() -> String { + "ApiKey".to_string() + } +} + +impl Apikey { + pub fn new(name: &str, user_id: Option) -> (Self, String) { + let secret: String = thread_rng() + .sample_iter(&Alphanumeric) + .take(64) + .map(char::from) + .collect(); + + let hash = Self::hash_secret(&secret); + + ( + Self { + id: None, + name: name.to_string(), + hash, + user_id, + }, + secret, + ) + } + + fn hash_secret(secret: &str) -> String { + let mut hasher = Sha3_256::default(); + + hasher.update(secret); + + let hash = hasher.finalize_fixed().to_vec(); + + general_purpose::STANDARD_NO_PAD.encode(hash) + } + + pub fn find_key_from_secret(db: &Database, secret: &str) -> Result, Error> { + let hash = Self::hash_secret(secret); + + Ok(db.filter(move |_, key: &Self| key.hash == hash)?.next()) + } +} diff --git a/src/models/mod.rs b/src/models/mod.rs index a62659c..aa208c1 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,4 +1,5 @@ pub mod album; +pub mod api_key; pub mod insult_compliment; pub mod motivation; pub mod random;