@@ -1,5 +1,5 @@
 // Stract is an open source web search engine.
-// Copyright (C) 2023 Stract ApS
+// Copyright (C) 2024 Stract ApS
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU Affero General Public License as
@@ -55,32 +55,10 @@ use crate::{query, webgraph, Result};
 use self::sidebar::SidebarManager;
 use self::widget::WidgetManager;
 
-use super::{distributed, live, SearchQuery, SearchResult, WebsitesResult};
+use super::{distributed, ScoredWebpagePointer, SearchQuery, SearchResult, WebsitesResult};
 
 const NUM_PIPELINE_RANKING_RESULTS: usize = 300;
 
-#[derive(Clone)]
-pub enum ScoredWebpagePointer {
-    Normal(distributed::ScoredWebpagePointer),
-    Live(live::ScoredWebpagePointer),
-}
-
-impl ScoredWebpagePointer {
-    pub fn as_ranking(&self) -> &RecallRankingWebpage {
-        match self {
-            ScoredWebpagePointer::Normal(p) => &p.website,
-            ScoredWebpagePointer::Live(p) => &p.website,
-        }
-    }
-
-    pub fn as_ranking_mut(&mut self) -> &mut RecallRankingWebpage {
-        match self {
-            ScoredWebpagePointer::Normal(p) => &mut p.website,
-            ScoredWebpagePointer::Live(p) => &mut p.website,
-        }
-    }
-}
-
 impl RankableWebpage for ScoredWebpagePointer {
     fn set_raw_score(&mut self, score: f64) {
         self.as_ranking_mut().set_raw_score(score);
@@ -103,17 +81,11 @@ impl RankableWebpage for ScoredWebpagePointer {
     }
 
     fn signals(&self) -> &EnumMap<SignalEnum, SignalCalculation> {
-        match self {
-            ScoredWebpagePointer::Normal(p) => p.website.signals(),
-            ScoredWebpagePointer::Live(p) => p.website.signals(),
-        }
+        self.as_ranking().signals()
     }
 
     fn signals_mut(&mut self) -> &mut EnumMap<SignalEnum, SignalCalculation> {
-        match self {
-            ScoredWebpagePointer::Normal(p) => p.website.signals_mut(),
-            ScoredWebpagePointer::Live(p) => p.website.signals_mut(),
-        }
+        self.as_ranking_mut().signals_mut()
     }
 }
 
@@ -243,10 +215,9 @@ where
     }
 }
 
-pub struct ApiSearcher<S, L, G> {
+pub struct ApiSearcher<S, G> {
     distributed_searcher: Arc<S>,
     sidebar_manager: Option<SidebarManager>,
-    live_searcher: Option<L>,
     cross_encoder: Option<Arc<CrossEncoderModel>>,
     lambda_model: Option<Arc<LambdaMART>>,
     dual_encoder: Option<Arc<DualEncoder>>,
@@ -257,10 +228,9 @@ pub struct ApiSearcher<S, L, G> {
     webgraph: Option<G>,
 }
 
-impl<S, L, G> ApiSearcher<S, L, G>
+impl<S, G> ApiSearcher<S, G>
 where
     S: distributed::SearchClient,
-    L: live::SearchClient,
     G: Graph,
 {
     pub async fn new<C>(
@@ -284,7 +254,6 @@ where
         Self {
             distributed_searcher: dist_searcher,
             sidebar_manager,
-            live_searcher: None,
             cross_encoder: None,
             lambda_model: None,
             dual_encoder: None,
@@ -298,11 +267,6 @@ where
         }
     }
 
-    pub fn with_live(mut self, live_searcher: L) -> Self {
-        self.live_searcher = Some(live_searcher);
-        self
-    }
-
     pub fn with_cross_encoder(mut self, cross_encoder: CrossEncoderModel) -> Self {
         self.cross_encoder = Some(Arc::new(cross_encoder));
         self
@@ -436,64 +400,9 @@ where
         query: &str,
         top_websites: &[ScoredWebpagePointer],
     ) -> Vec<PrecisionRankingWebpage> {
-        let normal: Vec<_> = top_websites
-            .iter()
-            .enumerate()
-            .filter_map(|(i, pointer)| {
-                if let ScoredWebpagePointer::Normal(p) = pointer {
-                    Some((i, p.clone()))
-                } else {
-                    None
-                }
-            })
-            .collect();
-
-        let live: Vec<_> = top_websites
-            .iter()
-            .enumerate()
-            .filter_map(|(i, pointer)| {
-                if let ScoredWebpagePointer::Live(p) = pointer {
-                    Some((i, p.clone()))
-                } else {
-                    None
-                }
-            })
-            .collect();
-
-        let (retrieved_normal, retrieved_live) = tokio::join!(
-            self.distributed_searcher.retrieve_webpages(&normal, query),
-            self.retrieve_webpages_from_live(&live, query),
-        );
-
-        let mut retrieved_webpages: Vec<_> =
-            retrieved_normal.into_iter().chain(retrieved_live).collect();
-        retrieved_webpages.sort_by(|(a, _), (b, _)| a.cmp(b));
-
-        retrieved_webpages
-            .into_iter()
-            .map(|(_, webpage)| webpage)
-            .collect::<Vec<_>>()
-    }
-
-    async fn search_initial_from_live(
-        &self,
-        query: &SearchQuery,
-    ) -> Option<Vec<live::InitialSearchResultShard>> {
-        match &self.live_searcher {
-            Some(searcher) => Some(searcher.search_initial(query).await),
-            None => None,
-        }
-    }
-
-    async fn retrieve_webpages_from_live(
-        &self,
-        pointers: &[(usize, live::ScoredWebpagePointer)],
-        query: &str,
-    ) -> Vec<(usize, PrecisionRankingWebpage)> {
-        match &self.live_searcher {
-            Some(searcher) => searcher.retrieve_webpages(pointers, query).await,
-            None => vec![],
-        }
+        self.distributed_searcher
+            .retrieve_webpages(top_websites, query)
+            .await
     }
 
     async fn inbound_vecs(&self, ids: &[webgraph::NodeID]) -> Vec<bitvec_similarity::BitVec> {
@@ -507,7 +416,6 @@ where
         &self,
         query: &SearchQuery,
         initial_results: Vec<distributed::InitialSearchResultShard>,
-        live_results: Vec<live::InitialSearchResultShard>,
    ) -> (Vec<ScoredWebpagePointer>, bool) {
         let mut collector =
             BucketCollector::new(NUM_PIPELINE_RANKING_RESULTS, self.collector_config.clone());
@@ -518,17 +426,7 @@ where
             .map(|r| *r.host_id())
             .collect::<Vec<_>>();
 
-        let live_host_nodes = live_results
-            .iter()
-            .flat_map(|r| r.local_result.websites.iter())
-            .map(|r| *r.host_id())
-            .collect::<Vec<_>>();
-
-        let host_nodes = initial_host_nodes
-            .into_iter()
-            .chain(live_host_nodes)
-            .unique()
-            .collect::<Vec<_>>();
+        let host_nodes = initial_host_nodes.into_iter().unique().collect::<Vec<_>>();
 
         let inbound_vecs = if !query.fetch_backlinks() {
             HashMap::default()
@@ -555,26 +453,6 @@ where
                     shard: result.shard,
                 };
 
-                let pointer = ScoredWebpagePointer::Normal(pointer);
-
-                collector.insert(pointer);
-            }
-        }
-
-        for result in live_results {
-            num_results += result.local_result.websites.len();
-            for website in result.local_result.websites {
-                let inbound = inbound_vecs
-                    .get(website.host_id())
-                    .cloned()
-                    .unwrap_or_default();
-                let pointer = live::ScoredWebpagePointer {
-                    website: RecallRankingWebpage::new(website, inbound),
-                    shard_id: result.shard_id,
-                };
-
-                let pointer = ScoredWebpagePointer::Live(pointer);
-
                 collector.insert(pointer);
             }
         }
@@ -648,7 +526,7 @@ where
             .map(|result| result.local_result.num_websites)
             .fold(approx_count::Count::Exact(0), |acc, count| acc + count);
 
-        let (combined, _) = self.combine_results(query, results, vec![]).await;
+        let (combined, _) = self.combine_results(query, results).await;
         let combined: Vec<_> = combined.into_iter().take(query.num_results).collect();
 
         let mut retrieved_webpages: Vec<_> = self
@@ -696,24 +574,17 @@ where
             ..query.clone()
         };
 
-        let (initial_results, live_results) = tokio::join!(
-            self.distributed_searcher.search_initial(&search_query),
-            self.search_initial_from_live(&search_query),
-        );
+        let initial_results = self
+            .distributed_searcher
+            .search_initial(&search_query)
+            .await;
 
         let num_docs = initial_results
             .iter()
             .map(|result| result.local_result.num_websites)
-            .chain(live_results.iter().flat_map(|results| {
-                results
-                    .iter()
-                    .map(|result| result.local_result.num_websites)
-            }))
             .fold(approx_count::Count::Exact(0), |acc, count| acc + count);
 
-        let (top_websites, has_more_results) = self
-            .combine_results(query, initial_results, live_results.unwrap_or_default())
-            .await;
+        let (top_websites, has_more_results) = self.combine_results(query, initial_results).await;
 
         let inbound_scorer = self.inbound_scorer(query).await;
 
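
The rewritten `signals`/`signals_mut` bodies lean on the `as_ranking`/`as_ranking_mut` accessors (now defined alongside `ScoredWebpagePointer` in the parent module) instead of repeating a `match` in every trait method. Below is a minimal, self-contained sketch of that delegation pattern; it uses hypothetical stand-in types (`Ranking`, `Pointer`, `Rankable`) rather than Stract's actual `RecallRankingWebpage`, `ScoredWebpagePointer`, and `RankableWebpage` definitions.

// Stand-in for RecallRankingWebpage: owns the mutable ranking state.
struct Ranking {
    score: f64,
}

// Stand-in for ScoredWebpagePointer: every variant wraps the same
// ranking payload, so accessors can funnel through one `match`.
enum Pointer {
    Normal(Ranking),
    Live(Ranking),
}

impl Pointer {
    // Mirrors `as_ranking` in the diff: the only place that matches
    // on the variants.
    fn as_ranking(&self) -> &Ranking {
        match self {
            Pointer::Normal(r) | Pointer::Live(r) => r,
        }
    }

    // Mirrors `as_ranking_mut`.
    fn as_ranking_mut(&mut self) -> &mut Ranking {
        match self {
            Pointer::Normal(r) | Pointer::Live(r) => r,
        }
    }
}

// Stand-in for RankableWebpage: trait methods delegate to the
// accessors rather than re-matching, as in the rewritten
// `signals`/`signals_mut` above.
trait Rankable {
    fn score(&self) -> f64;
    fn set_score(&mut self, score: f64);
}

impl Rankable for Pointer {
    fn score(&self) -> f64 {
        self.as_ranking().score
    }

    fn set_score(&mut self, score: f64) {
        self.as_ranking_mut().score = score;
    }
}

fn main() {
    let mut normal = Pointer::Normal(Ranking { score: 0.0 });
    normal.set_score(1.5);
    assert_eq!(normal.score(), 1.5);

    let live = Pointer::Live(Ranking { score: 2.0 });
    assert_eq!(live.score(), 2.0);
}

With the accessors as the single funnel, adding or removing a variant (as this diff removes the live path) touches one `match` per accessor instead of every `RankableWebpage` method.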