127 lines
3.0 KiB
Rust
127 lines
3.0 KiB
Rust
use crate::error::Error;
|
|
use j_db::database::Database;
|
|
use j_db::model::JdbModel;
|
|
use rand::distr::weighted::WeightedIndex;
|
|
use rand::prelude::Distribution;
|
|
use rand::rng;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashSet;
|
|
|
|
/// Score given to every `Random` created through `Random::new`, and the value all scores in a
/// pool are reset to by `RandomConfig::get_response` once none of them is above zero.
const MAX_SCORE: u16 = 10;
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
|
pub struct Random {
|
|
id: Option<u64>,
|
|
pub response: String,
|
|
pub score: u16,
|
|
}
|
|
|
|
impl Random {
|
|
pub fn new(response: &str) -> Self {
|
|
Self {
|
|
id: None,
|
|
response: response.to_string(),
|
|
score: MAX_SCORE,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl JdbModel for Random {
    /// Returns the id currently held by this response, if any.
    fn id(&self) -> Option<u64> {
        self.id
    }

    /// Stores `id` on this response, overwriting whatever was there before.
    fn set_id(&mut self, id: u64) {
        self.id = Some(id);
    }

    /// Name of the tree this model is kept under.
    fn tree() -> String {
        String::from("random_opt")
    }

    /// Returns `true` when this response's text differs from `other`'s, ignoring ASCII case.
    // NOTE(review): presumably the database rejects inserts for which this returns `false`;
    // confirm against the j_db contract.
    fn check_unique(&self, other: &Self) -> bool {
        let same_text = self.response.eq_ignore_ascii_case(&other.response);
        !same_text
    }
}
|
|
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
|
|
pub struct RandomConfig {
|
|
id: Option<u64>,
|
|
pub name: String,
|
|
pub responses: HashSet<u64>,
|
|
}
|
|
|
|
impl JdbModel for RandomConfig {
    /// Returns the id currently held by this pool, if any.
    fn id(&self) -> Option<u64> {
        self.id
    }

    /// Stores `id` on this pool, overwriting whatever was there before.
    fn set_id(&mut self, id: u64) {
        self.id = Some(id);
    }

    /// Name of the tree this model is kept under.
    fn tree() -> String {
        String::from("randoms")
    }

    /// Always returns `true`: no uniqueness constraint is expressed between pools here.
    fn check_unique(&self, _other: &Self) -> bool {
        true
    }
}
|
|
|
|
impl RandomConfig {
|
|
pub fn add_random(db: &Database, name: &str, response: &str) -> Result<(), Error> {
|
|
let mut randoms = match Self::get_random(db, name)? {
|
|
None => db.insert::<RandomConfig>(Self {
|
|
id: None,
|
|
name: name.to_string(),
|
|
responses: HashSet::new(),
|
|
})?,
|
|
Some(random) => random,
|
|
};
|
|
|
|
let random = Random::new(response);
|
|
|
|
let random = db.insert(random)?;
|
|
|
|
randoms.responses.insert(random.id().unwrap());
|
|
|
|
db.insert::<RandomConfig>(randoms)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub fn get_random(db: &Database, name: &str) -> Result<Option<Self>, Error> {
|
|
Ok(db
|
|
.filter(|_, random: &RandomConfig| random.name.eq_ignore_ascii_case(name))?
|
|
.next())
|
|
}
|
|
|
|
pub fn get_response(&self, db: &Database) -> Result<Option<String>, Error> {
|
|
let mut responses: Vec<Random> = db
|
|
.filter(|_, random: &Random| self.responses.contains(&random.id().unwrap()))?
|
|
.collect();
|
|
|
|
if !responses.iter().any(|r| r.score > 0) {
|
|
for response in &mut responses {
|
|
response.score = MAX_SCORE;
|
|
|
|
db.insert(response.clone())?;
|
|
}
|
|
}
|
|
|
|
if responses.is_empty() {
|
|
return Ok(None);
|
|
}
|
|
|
|
let dist = WeightedIndex::new(responses.iter().map(|r| r.score)).unwrap();
|
|
|
|
let mut resp = responses[dist.sample(&mut rng())].clone();
|
|
|
|
resp.score = resp.score.saturating_sub(1);
|
|
|
|
let resp = db.insert(resp)?;
|
|
|
|
Ok(Some(resp.response))
|
|
}
|
|
}
|