Added bot api

+ Right now can just play voices
+ Added commands to add api keys
+ clippy + fmt
This commit is contained in:
Joey Hines 2023-01-21 15:36:14 -07:00
parent 6896ee610c
commit 41213db8d4
Signed by: joeyahines
GPG Key ID: 995E531F7A569DDB
10 changed files with 396 additions and 39 deletions

140
Cargo.lock generated
View File

@ -126,6 +126,56 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "axum"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "678c5130a507ae3a7c797f9a17393c14849300b8440eac47cdb90a5bdcb3a543"
dependencies = [
"async-trait",
"axum-core",
"bitflags",
"bytes",
"futures-util",
"http",
"http-body",
"hyper",
"itoa",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower",
"tower-http",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1cae3e661676ffbacb30f1a824089a8c9150e71017f7e1e38f2aa32009188d34"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http",
"http-body",
"mime",
"rustversion",
"tower-layer",
"tower-service",
]
[[package]]
name = "base64"
version = "0.13.1"
@ -606,6 +656,8 @@ dependencies = [
name = "fren"
version = "0.1.0"
dependencies = [
"axum",
"base64 0.21.0",
"config",
"j_db",
"json",
@ -616,6 +668,7 @@ dependencies = [
"reqwest",
"serde",
"serenity",
"sha3",
"songbird",
"structopt",
"tera",
@ -874,6 +927,12 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "http-range-header"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29"
[[package]]
name = "httparse"
version = "1.8.0"
@ -1064,6 +1123,15 @@ dependencies = [
"serde",
]
[[package]]
name = "keccak"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3afef3b6eff9ce9d8ff9b3601125eec7f0c8cbac7abd14f355d053fa56c98768"
dependencies = [
"cpufeatures",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
@ -1167,6 +1235,12 @@ dependencies = [
"regex-automata",
]
[[package]]
name = "matchit"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40"
[[package]]
name = "memchr"
version = "2.5.0"
@ -2001,6 +2075,15 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_path_to_error"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341"
dependencies = [
"serde",
]
[[package]]
name = "serde_repr"
version = "0.1.10"
@ -2094,6 +2177,16 @@ dependencies = [
"digest",
]
[[package]]
name = "sha3"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdf0c33fae925bdc080598b84bc15c55e7b9a4a43b3c704da051f977469691c9"
dependencies = [
"digest",
"keccak",
]
[[package]]
name = "sharded-slab"
version = "0.1.4"
@ -2299,6 +2392,12 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "sync_wrapper"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8"
[[package]]
name = "tempfile"
version = "3.3.0"
@ -2498,6 +2597,47 @@ dependencies = [
"serde",
]
[[package]]
name = "tower"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
dependencies = [
"futures-core",
"futures-util",
"pin-project",
"pin-project-lite",
"tokio",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower-http"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f873044bf02dd1e8239e9c1293ea39dad76dc594ec16185d0a1bf31d8dc8d858"
dependencies = [
"bitflags",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-range-header",
"pin-project-lite",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-layer"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0"
[[package]]
name = "tower-service"
version = "0.3.2"

View File

@ -19,6 +19,9 @@ magick_rust = "0.17.0"
songbird = "0.3.0"
json = "0.12.4"
j_db = {git = "https://git.jojodev.com/joeyahines/j_db"}
axum = "0.6.3"
sha3 = "0.10.6"
base64 = "0.21.0"
[dependencies.serenity]
version = "0.11.5"

57
src/api/mod.rs Normal file
View File

@ -0,0 +1,57 @@
use crate::config::GlobalData;
use crate::discord::voices::speak;
use crate::models::api_key::Apikey;
use axum::extract::State;
use axum::{http::StatusCode, response::IntoResponse, routing::post, Json, Router};
use serde::Deserialize;
use serenity::prelude::Context;
pub async fn web_server(ctx: Context) {
let addr = {
let data = ctx.data.read().await;
let global_data = data.get::<GlobalData>().unwrap();
global_data.cfg.api_addr
};
let app = Router::new()
.route("/play", post(play_sound))
.with_state(ctx);
println!("Serving bot api on: {}", addr);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
}
#[derive(Deserialize)]
struct SoundPayload {
pub api_key: String,
pub voice: String,
pub phrase: String,
}
async fn play_sound(
State(ctx): State<Context>,
Json(payload): Json<SoundPayload>,
) -> impl IntoResponse {
let data = ctx.data.read().await;
let global_data = data.get::<GlobalData>().unwrap();
if let Some(api_key) = Apikey::find_key_from_secret(&global_data.db, &payload.api_key).unwrap()
{
if let Some(user_id) = api_key.user_id {
speak(
&ctx,
global_data.cfg.guild_id,
user_id,
&payload.voice,
&payload.phrase,
)
.await
.unwrap();
}
}
StatusCode::ACCEPTED
}

View File

@ -6,9 +6,10 @@ use config::{Config, File};
use j_db::database::Database;
use rand::prelude::SliceRandom;
use serde::{Deserialize, Serialize};
use serenity::model::prelude::UserId;
use serenity::model::prelude::{GuildId, UserId};
use serenity::prelude::TypeMapKey;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use structopt::StructOpt;
@ -30,6 +31,8 @@ pub struct BotConfig {
pub nft_path: PathBuf,
pub db_path: PathBuf,
pub admins: Vec<UserId>,
pub guild_id: GuildId,
pub api_addr: SocketAddr,
}
impl BotConfig {
@ -47,6 +50,7 @@ pub struct BotState {
pub accepted_nsfw: Option<UserId>,
pub albums: HashMap<String, Vec<Image>>,
pub bad_apple_running: bool,
pub speak_lock: Mutex<()>,
}
impl BotState {
@ -64,6 +68,7 @@ impl BotState {
accepted_nsfw: None,
albums,
bad_apple_running: false,
speak_lock: Mutex::new(()),
})
}

View File

@ -1,4 +1,5 @@
use crate::config::BotConfig;
use crate::models::api_key::Apikey;
use crate::{command, group, GlobalData};
use json::JsonValue;
use serenity::client::Context;
@ -8,7 +9,7 @@ use serenity::model::prelude::UserId;
use std::borrow::Cow;
#[group]
#[commands(reload, dump_db, load_db)]
#[commands(reload, dump_db, load_db, add_key)]
pub struct ADMIN;
pub fn is_admin(user_id: &UserId, cfg: &BotConfig) -> bool {
@ -96,3 +97,33 @@ async fn load_db(ctx: &Context, msg: &Message, _args: Args) -> CommandResult {
Ok(())
}
#[command]
async fn add_key(ctx: &Context, msg: &Message, args: Args) -> CommandResult {
let mut data = ctx.data.write().await;
let global_data = data.get_mut::<GlobalData>().unwrap();
if !is_admin(&msg.author.id, &global_data.cfg) {
return Ok(());
}
let (api_key, key) = if args.len() == 1 {
let user_id = args.parse::<UserId>()?;
let user = user_id.to_user(&ctx.http).await?;
Apikey::new(&format!("{}'s Key", user.name), Some(user_id))
} else {
Apikey::new(&format!("{}'s Key", msg.author.name), Some(msg.author.id))
};
global_data.db.insert::<Apikey>(api_key.clone())?;
let dm = msg.author.create_dm_channel(&ctx.http).await?;
dm.say(
&ctx.http,
format!("Key '{}' added. Api Key: {}", api_key.name, key),
)
.await?;
Ok(())
}

View File

@ -10,6 +10,7 @@ pub mod shop;
pub mod story;
pub mod voices;
use crate::api::web_server;
use crate::discord::fren_coin::give_coin;
use crate::discord::joke::random;
use crate::discord::shop::restock_shop;
@ -126,6 +127,8 @@ impl EventHandler for Handler {
OnlineStatus::Online,
)
.await;
tokio::spawn(async move { web_server(ctx).await });
}
}

View File

@ -2,12 +2,15 @@ use crate::{command, group, GlobalData};
use serenity::client::Context;
use serenity::framework::standard::{Args, CommandResult};
use serenity::model::channel::{AttachmentType, Message};
use serenity::model::id::UserId;
use serenity::model::prelude::GuildId;
use serenity::utils::MessageBuilder;
use songbird::driver::Bitrate;
use songbird::input;
use songbird::input::cached::Compressed;
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::path::{Path, PathBuf};
#[group]
@ -75,33 +78,58 @@ async fn find_voice(voice_path: &Path, name: &str) -> Result<Option<PathBuf>, to
Ok(None)
}
#[command]
#[only_in(guilds)]
#[min_args(1)]
async fn say(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let guild = msg.guild(&ctx.cache).unwrap();
let guild_id = guild.id;
#[derive(Debug)]
pub enum VoiceError {
VoiceNotFound(String),
WordNotFound(String),
NotInVoiceChannel,
Serenity(serenity::Error),
}
impl Display for VoiceError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
VoiceError::VoiceNotFound(v) => write!(f, "{} voice not found", v),
VoiceError::WordNotFound(w) => write!(f, "{} is not in dictionary", w),
VoiceError::Serenity(err) => write!(f, "Serenity error: {}", err),
VoiceError::NotInVoiceChannel => write!(f, "User not in voice channel"),
}
}
}
impl From<serenity::Error> for VoiceError {
fn from(value: serenity::Error) -> Self {
Self::Serenity(value)
}
}
impl std::error::Error for VoiceError {}
pub async fn speak(
ctx: &Context,
guild_id: GuildId,
user_id: UserId,
voice: &str,
phrase: &str,
) -> Result<(), VoiceError> {
let data = ctx.data.read().await;
let global_data = data.get::<GlobalData>().unwrap();
let voice = args.parse::<String>().unwrap();
let _ = global_data.bot_state.speak_lock.lock().await;
let voice_path = match find_voice(&global_data.cfg.voice_path, &voice).await? {
None => {
msg.reply(&ctx.http, format!("No voice found called '{}'", voice))
.await?;
return Ok(());
}
let voice_path = match find_voice(&global_data.cfg.voice_path, voice)
.await
.unwrap()
{
None => return Err(VoiceError::VoiceNotFound(voice.to_string())),
Some(voice_path) => voice_path,
};
args.advance();
let dict = get_voice_dictionary(&voice_path).await.unwrap();
let dict = get_voice_dictionary(&voice_path).await?;
let mut phrase = Vec::new();
while let Some(word) = &args.current() {
let mut sentence = Vec::new();
for word in phrase.split(' ') {
let word = word.to_lowercase();
let mut add_period = false;
let mut add_comma = false;
let word = if word.ends_with(',') {
@ -115,38 +143,31 @@ async fn say(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
};
if dict.contains_key(&word) {
phrase.push(word.to_string());
sentence.push(word.to_string());
} else {
msg.reply(
&ctx.http,
format!("The word '{}' is not in the dictionary", word),
)
.await?;
return Ok(());
return Err(VoiceError::WordNotFound(word));
}
if add_comma {
phrase.push("_comma".to_string());
sentence.push("_comma".to_string());
}
if add_period {
phrase.push("_period".to_string());
sentence.push("_period".to_string());
}
}
args.advance();
}
let guild = guild_id.to_guild_cached(&ctx.cache).unwrap();
let channel_id = guild
.voice_states
.get(&msg.author.id)
.get(&user_id)
.and_then(|voice_state| voice_state.channel_id);
let connect_to = match channel_id {
Some(channel) => channel,
None => {
msg.reply(ctx, "You are not in a voice channel").await?;
return Ok(());
return Err(VoiceError::NotInVoiceChannel);
}
};
@ -157,11 +178,11 @@ async fn say(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let (handler_lock, success_reader) = manager.join(guild_id, connect_to).await;
success_reader?;
success_reader.unwrap();
let mut handler = handler_lock.lock().await;
for word in phrase {
for word in sentence {
let word_path = dict.get(&word).unwrap();
let audio_src = Compressed::new(
@ -175,12 +196,37 @@ async fn say(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let duration = audio_src.metadata.duration.unwrap();
let voice = handler.play_source(audio_src.into());
voice.set_volume(0.5)?;
voice.set_volume(0.5).unwrap();
tokio::time::sleep(duration).await;
}
handler.leave().await?;
handler.leave().await.unwrap();
Ok(())
}
#[command]
#[only_in(guilds)]
#[min_args(1)]
async fn say(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
let guild = msg.guild(&ctx.cache).unwrap();
let guild_id = guild.id;
let voice = args.parse::<String>()?;
args.advance();
let phrase = args.rest();
if let Err(err) = speak(ctx, guild_id, msg.author.id, &voice, phrase).await {
match err {
VoiceError::VoiceNotFound(_)
| VoiceError::WordNotFound(_)
| VoiceError::NotInVoiceChannel => {
msg.reply(&ctx.http, format!("Error: {}", err)).await?;
}
_ => return Err(err.into()),
}
}
Ok(())
}

View File

@ -1,3 +1,4 @@
mod api;
mod config;
mod discord;
mod error;

70
src/models/api_key.rs Normal file
View File

@ -0,0 +1,70 @@
use crate::error::Error;
use base64::{engine::general_purpose, Engine as _};
use j_db::database::Database;
use j_db::model::JdbModel;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
use serenity::model::id::UserId;
use sha3::digest::FixedOutput;
use sha3::{Digest, Sha3_256};
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Apikey {
id: Option<u64>,
pub name: String,
pub hash: String,
pub user_id: Option<UserId>,
}
impl JdbModel for Apikey {
fn id(&self) -> Option<u64> {
self.id
}
fn set_id(&mut self, id: u64) {
self.id = Some(id)
}
fn tree() -> String {
"ApiKey".to_string()
}
}
impl Apikey {
pub fn new(name: &str, user_id: Option<UserId>) -> (Self, String) {
let secret: String = thread_rng()
.sample_iter(&Alphanumeric)
.take(64)
.map(char::from)
.collect();
let hash = Self::hash_secret(&secret);
(
Self {
id: None,
name: name.to_string(),
hash,
user_id,
},
secret,
)
}
fn hash_secret(secret: &str) -> String {
let mut hasher = Sha3_256::default();
hasher.update(secret);
let hash = hasher.finalize_fixed().to_vec();
general_purpose::STANDARD_NO_PAD.encode(hash)
}
pub fn find_key_from_secret(db: &Database, secret: &str) -> Result<Option<Self>, Error> {
let hash = Self::hash_secret(secret);
Ok(db.filter(move |_, key: &Self| key.hash == hash)?.next())
}
}

View File

@ -1,4 +1,5 @@
pub mod album;
pub mod api_key;
pub mod insult_compliment;
pub mod motivation;
pub mod random;