Removing rate limiting.

This commit is contained in:
Dessalines 2019-11-30 13:41:52 -08:00
parent 8634b1e2bb
commit edf6434854
1 changed files with 9 additions and 101 deletions

View File

@ -13,29 +13,12 @@ extern crate r2d2_sqlite;
use actix_files as fs; use actix_files as fs;
use actix_files::NamedFile; use actix_files::NamedFile;
use actix_web::{web, App, HttpRequest, HttpResponse, HttpServer}; use actix_web::{web, App, HttpResponse, HttpServer};
use failure::Error; use failure::Error;
use r2d2_sqlite::SqliteConnectionManager; use r2d2_sqlite::SqliteConnectionManager;
use rusqlite::params; use rusqlite::params;
use std::collections::HashMap;
use std::env; use std::env;
use std::ops::Deref; use std::ops::Deref;
use std::sync::Mutex;
use std::time::SystemTime;
const RATE_LIMIT: i32 = 10;
const RATE_LIMIT_PER_SECOND: i32 = 60;
pub struct State {
rate_limits: Mutex<HashMap<String, RateLimitBucket>>,
pool: Mutex<r2d2::Pool<SqliteConnectionManager>>,
}
#[derive(Debug)]
pub struct RateLimitBucket {
last_checked: SystemTime,
allowance: f64,
}
fn main() { fn main() {
println!("Access me at {}", endpoint()); println!("Access me at {}", endpoint());
@ -43,17 +26,12 @@ fn main() {
let manager = SqliteConnectionManager::file(torrents_db_file()); let manager = SqliteConnectionManager::file(torrents_db_file());
let pool = r2d2::Pool::builder().max_size(15).build(manager).unwrap(); let pool = r2d2::Pool::builder().max_size(15).build(manager).unwrap();
let shared_data = web::Data::new(State {
rate_limits: Mutex::new(HashMap::new()),
pool: Mutex::new(pool),
});
HttpServer::new(move || { HttpServer::new(move || {
App::new() App::new()
.route("/", web::get().to(index)) .route("/", web::get().to(index))
.service(fs::Files::new("/static", front_end_dir())) .service(fs::Files::new("/static", front_end_dir()))
.register_data(shared_data.clone()) .data(pool.clone())
.route("/service/search", web::get().to(search)) .route("/service/search", web::get().to_async(search))
}) })
.bind(endpoint()) .bind(endpoint())
.unwrap() .unwrap()
@ -86,18 +64,9 @@ struct SearchQuery {
} }
fn search( fn search(
req: HttpRequest, db: web::Data<r2d2::Pool<SqliteConnectionManager>>,
data: web::Data<State>,
query: web::Query<SearchQuery>, query: web::Query<SearchQuery>,
) -> HttpResponse { ) -> HttpResponse {
let ip = req
.connection_info()
.remote()
.unwrap_or("127.0.0.1:12345")
.split(":")
.next()
.unwrap_or("127.0.0.1")
.to_string();
if query.q.is_empty() { if query.q.is_empty() {
return HttpResponse::BadRequest() return HttpResponse::BadRequest()
@ -106,18 +75,12 @@ fn search(
.body(format!("{{\"error\": \"{}\"}}", "Empty query".to_string())); .body(format!("{{\"error\": \"{}\"}}", "Empty query".to_string()));
} }
let conn = data.pool.lock().unwrap().get().unwrap(); let conn = db.get().unwrap();
match check_rate_limit_full(data, &ip, RATE_LIMIT, RATE_LIMIT_PER_SECOND) { HttpResponse::Ok()
Ok(_) => HttpResponse::Ok()
.header("Access-Control-Allow-Origin", "*") .header("Access-Control-Allow-Origin", "*")
.content_type("application/json") .content_type("application/json")
.body(search_query(query, conn).unwrap()), .body(search_query(query, conn).unwrap())
Err(e) => HttpResponse::BadRequest()
.header("Access-Control-Allow-Origin", "*")
.content_type("application/json")
.body(format!("{{\"error\": \"{}\"}}", e.to_string())),
}
} }
fn search_query( fn search_query(
@ -239,61 +202,6 @@ fn torrent_file_search(
Ok(files) Ok(files)
} }
fn check_rate_limit_full(
state: web::Data<State>,
ip: &str,
rate: i32,
per: i32,
) -> Result<(), Error> {
let mut rate_limits = state.rate_limits.lock().unwrap();
if rate_limits.get_mut(ip).is_none() {
rate_limits.insert(
ip.to_string(),
RateLimitBucket {
last_checked: SystemTime::now(),
allowance: -2f64,
},
);
}
if let Some(rate_limit) = rate_limits.get_mut(ip) {
// The initial value
if rate_limit.allowance == -2f64 {
rate_limit.allowance = rate as f64;
};
let current = SystemTime::now();
let time_passed = current
.duration_since(rate_limit.last_checked)
.unwrap()
.as_secs() as f64;
rate_limit.last_checked = current;
rate_limit.allowance += time_passed * (rate as f64 / per as f64);
if rate_limit.allowance > rate as f64 {
rate_limit.allowance = rate as f64;
}
if rate_limit.allowance < 1.0 {
println!(
"Rate limited IP: {}, time_passed: {}, allowance: {}",
&ip, time_passed, rate_limit.allowance
);
Err(format_err!(
"Too many requests for IP: {}. {} per {} seconds",
&ip,
rate,
per
))
.unwrap()
} else {
rate_limit.allowance -= 1.0;
Ok(())
}
} else {
Ok(())
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use time::PreciseTime; use time::PreciseTime;