// ferret/core/src/main.rs
//
// 216 lines
// 5.2 KiB
// Rust

#[macro_use]
extern crate log;
use ammonia::clean;
use axum::{
body::Bytes,
extract::State,
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use chrono::{DateTime, NaiveDateTime, Utc};
use fuzzy_matcher::skim::SkimMatcherV2;
use fuzzy_matcher::FuzzyMatcher;
use scraper::{Html, Selector};
use serde::{Deserialize, Serialize};
use sqlx::sqlite::SqlitePool;
use std::collections::BTreeMap;
use std::env;
use std::net::SocketAddr;
use std::sync::Arc;
use url::Url;
use whatlang::{detect_lang, Lang};
/// Shared application state handed to the request handlers.
///
/// `main` wraps one instance in an `Arc` and registers it on the router with
/// `with_state`, so handlers receive it through the `State` extractor.
struct AppState {
    /// SQLite connection pool that handlers acquire connections from.
    pool: SqlitePool,
}
/// Program entry point: sets up logging, opens the SQLite pool named by the
/// `DATABASE_URL` environment variable, and serves the HTTP API on
/// `127.0.0.1:3000`.
///
/// # Panics
///
/// Panics if `DATABASE_URL` is not set, if the database connection cannot be
/// established, or if the HTTP server fails to bind or stops with an error.
/// There is nothing useful to do at startup without these, so each failure is
/// reported with a message naming the step that failed.
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();

    let database_url =
        env::var("DATABASE_URL").expect("the DATABASE_URL environment variable must be set");
    let pool = SqlitePool::connect(&database_url)
        .await
        .expect("failed to connect to the SQLite database named by DATABASE_URL");
    // update_index(&pool).await;

    let shared_state = Arc::new(AppState { pool });

    // NOTE(review): `search` extracts a JSON request body but is mounted on a
    // GET route, and `post` is imported without being used. Confirm which HTTP
    // method clients are expected to call before changing either side.
    let app = Router::new()
        // `GET /` goes to `root`
        .route("/", get(root))
        .route("/api/search", get(search))
        .with_state(shared_state);

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    tracing::debug!("listening on {}", addr);
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .expect("HTTP server terminated with an error");
}
/// Handler for `GET /`: replies with a fixed plain-text greeting.
async fn root() -> &'static str {
    const GREETING: &str = "Hello, World!";
    GREETING
}
/// Request payload for the search endpoint, deserialized from the JSON body.
#[derive(Deserialize)]
struct SearchQuery {
    /// Language value that rows of `search_index` must match to be considered.
    language: String,
    /// Text that page titles, summaries and URLs are matched against.
    include: String,
    /// Optional list of strings; accepted in the payload but not read by any
    /// code in this file.
    ignore: Option<Vec<String>>,
    /// Which matching strategy the search should use.
    option: SearchType,
}
/// Matching strategy requested by a [`SearchQuery`].
///
/// With serde's default enum representation a variant is selected by its name,
/// e.g. `"Fuzzy"`.
#[derive(Deserialize)]
enum SearchType {
    /// Fuzzy matching of the query text against title, summary and URL.
    Fuzzy,
    /// Regex matching; the `search` handler currently produces no results for it.
    Regex,
    /// SQL matching; the `search` handler currently produces no results for it.
    Sql,
}
/// One matching page as serialized in the search response.
#[derive(Serialize)]
struct SearchResult {
    /// Parsed address of the page.
    url: Url,
    /// Size value recorded for the page in the `search_index` table.
    size: i64,
    /// Page title as stored in the index.
    title: String,
    /// Page summary as stored in the index.
    summary: String,
    /// When the page was last updated, built from the stored Unix timestamp
    /// in seconds and expressed in UTC.
    last_updated: DateTime<Utc>,
}
/// Scores one indexed page against `query.include` using a fuzzy matcher.
///
/// The title match counts in full, while the summary and URL matches each
/// contribute half of their score. A field that does not match contributes 0.
///
/// # Parameters
///
/// * `title`, `summary`, `url` - the page fields to match against.
/// * `last_updated` - Unix timestamp in seconds, converted to a UTC date-time.
/// * `size` - copied unchanged into the result.
/// * `query` - only `query.include` is read, as the pattern to look for.
///
/// # Returns
///
/// `Some((score, result))` when the combined score is greater than 5.
/// `None` when the score is too low, and also when `url` cannot be parsed or
/// `last_updated` is outside the representable range: one malformed row is
/// treated as "no match" rather than aborting the whole request with a panic.
async fn fuzzy_search(
    title: &str,
    summary: &str,
    url: &str,
    last_updated: i64,
    size: i64,
    query: &SearchQuery,
) -> Option<(i64, SearchResult)> {
    // Combined scores at or below this value are treated as "no match".
    const MIN_SCORE: i64 = 5;

    let matcher = SkimMatcherV2::default();
    let pattern = query.include.as_str();
    let title_score = matcher.fuzzy_match(title, pattern).unwrap_or(0);
    let summary_score = matcher.fuzzy_match(summary, pattern).unwrap_or(0);
    let url_score = matcher.fuzzy_match(url, pattern).unwrap_or(0);

    // Each half is taken separately (integer division) before summing.
    let score = title_score + summary_score / 2 + url_score / 2;
    if score <= MIN_SCORE {
        return None;
    }

    // Both conversions depend on stored data, so failure yields `None`.
    let url = Url::parse(url).ok()?;
    let naive = NaiveDateTime::from_timestamp_opt(last_updated, 0)?;
    let last_updated = DateTime::<Utc>::from_utc(naive, Utc);

    Some((
        score,
        SearchResult {
            url,
            size,
            title: title.to_string(),
            summary: summary.to_string(),
            last_updated,
        },
    ))
}
/// Handler for `/api/search`: loads every `search_index` row whose language
/// equals `query.language` and returns the rows that match the query, keyed
/// by their score.
///
/// Only [`SearchType::Fuzzy`] is implemented; the other search types return an
/// empty map.
///
/// NOTE(review): the response map is keyed by score, so two pages with the
/// same score cannot both be returned; the row processed later replaces the
/// earlier one. Fixing that requires a different response shape.
///
/// # Panics
///
/// Panics if a database connection cannot be acquired or the query fails.
async fn search(
    State(state): State<Arc<AppState>>,
    Json(query): Json<SearchQuery>,
) -> Json<BTreeMap<i64, SearchResult>> {
    let mut conn = state
        .pool
        .acquire()
        .await
        .expect("failed to acquire a database connection for search");
    // The SQL text is kept exactly as it was so that any query metadata cached
    // for compile-time checking still matches.
    let rows = sqlx::query!(
        r#"
SELECT title, summary, url, content, last_updated, clicks, size
FROM search_index
WHERE language = ?1
ORDER BY last_updated
"#,
        query.language
    )
    .fetch_all(&mut *conn)
    .await
    .expect("failed to read candidate rows from search_index");

    let mut results = BTreeMap::new();
    for row in rows {
        match query.option {
            SearchType::Fuzzy => {
                let hit = fuzzy_search(
                    &row.title,
                    &row.summary,
                    &row.url,
                    row.last_updated,
                    row.size,
                    &query,
                )
                .await;
                if let Some((score, result)) = hit {
                    results.insert(score, result);
                }
            }
            // Listed explicitly so that adding a new search type becomes a
            // compile error here instead of silently returning nothing.
            SearchType::Regex | SearchType::Sql => {}
        }
    }
    Json(results)
}
/// Rebuilds `search_index` from the pages stored in `crawled_urls`.
///
/// For every crawled page this detects the language of the body, takes the
/// first `<title>` element as the title (falling back to the URL) and the
/// first `<p>` element as the summary (falling back to an empty string), and
/// writes the row with `REPLACE INTO`.
///
/// Pages whose language cannot be detected are skipped instead of aborting the
/// whole run.
///
/// NOTE(review): `REPLACE INTO` only overwrites an existing row when a
/// uniqueness constraint conflicts; confirm that `search_index` has one on
/// `url`, otherwise repeated runs add duplicate rows.
///
/// # Panics
///
/// Panics if a database connection cannot be acquired, or if reading
/// `crawled_urls` or writing `search_index` fails.
async fn update_index(pool: &SqlitePool) {
    let mut conn = pool
        .acquire()
        .await
        .expect("failed to acquire a database connection for indexing");
    // The SQL text of both queries is kept exactly as it was so that any query
    // metadata cached for compile-time checking still matches.
    let crawled = sqlx::query!(
        r#"
SELECT last_fetched, url, body
FROM crawled_urls
ORDER BY last_fetched
"#
    )
    .fetch_all(&mut *conn)
    .await
    .expect("failed to read crawled pages from crawled_urls");

    // The selectors do not depend on the page, so build them once.
    let title_selector = Selector::parse("title").expect("`title` is a valid CSS selector");
    let desc_selector = Selector::parse("p").expect("`p` is a valid CSS selector");

    for res in crawled {
        // Length of the page body in bytes. `size_of_val` on the string would
        // measure only its fixed-size header and give the same value for every
        // page. A `String` never exceeds `isize::MAX` bytes, so the cast is
        // lossless.
        let size = res.body.len() as i64;

        let lang = match detect_lang(&res.body) {
            Some(lang) => lang.code(),
            // Language detection can fail (for example on an empty body).
            None => continue,
        };

        // Parse in an inner scope so the document is dropped before the
        // database write is awaited.
        let (title, summary) = {
            let document = Html::parse_document(&res.body);
            let title = document
                .select(&title_selector)
                .next()
                .map(|element| element.inner_html())
                .unwrap_or_else(|| res.url.clone());
            let summary = document
                .select(&desc_selector)
                .next()
                .map(|element| element.inner_html())
                .unwrap_or_default();
            (title, summary)
        };

        sqlx::query!(
            r#"
REPLACE INTO search_index ( url, size, language, title, summary, content, last_updated )
VALUES ( ?1, ?2, ?3, ?4, ?5, ?6, ?7 )
"#,
            res.url,
            size,
            lang,
            title,
            summary,
            res.body,
            res.last_fetched,
        )
        .execute(&mut *conn)
        .await
        .expect("failed to write a row into search_index");
    }
}