Sfoglia il codice sorgente

optimise shortest path to use exact changed nodes (stored in hashset) if there are very few updated nodes

Mikkel Denker 8 mesi fa
parent
commit
774dbd87fb

+ 5 - 0
crates/core/src/ampc/dht/value.rs

@@ -53,6 +53,9 @@ impl ValueTrait for HarmonicMeta {}
 type ShortestPathMeta = crate::entrypoint::ampc::shortest_path::Meta;
 impl ValueTrait for ShortestPathMeta {}
 
+type ShortestPathChangedNodes = crate::entrypoint::ampc::shortest_path::UpdatedNodes;
+impl ValueTrait for ShortestPathChangedNodes {}
+
 impl ValueTrait for U64BloomFilter {}
 
 type Unit = ();
@@ -75,6 +78,7 @@ pub enum Value {
     HyperLogLog128(HyperLogLog128),
     HarmonicMeta(HarmonicMeta),
     ShortestPathMeta(ShortestPathMeta),
+    ShortestPathChangedNodes(ShortestPathChangedNodes),
     U64BloomFilter(U64BloomFilter),
     Unit(Unit),
 }
@@ -113,5 +117,6 @@ impl_from_to_value!(HyperLogLog64, HyperLogLog64);
 impl_from_to_value!(HyperLogLog128, HyperLogLog128);
 impl_from_to_value!(HarmonicMeta, HarmonicMeta);
 impl_from_to_value!(ShortestPathMeta, ShortestPathMeta);
+impl_from_to_value!(ShortestPathChangedNodes, ShortestPathChangedNodes);
 impl_from_to_value!(U64BloomFilter, U64BloomFilter);
 impl_from_to_value!(Unit, Unit);

+ 127 - 47
crates/core/src/entrypoint/ampc/shortest_path/mapper.rs

@@ -16,16 +16,19 @@
 
 use std::sync::{atomic::AtomicBool, Arc, Mutex};
 
-use bloom::U64BloomFilter;
 use rustc_hash::FxHashMap;
 
-use super::{DhtTable as _, Mapper, Meta, ShortestPathJob, ShortestPathTables};
+use super::{
+    updated_nodes::{UpdatedNodes, UpdatedNodesKind},
+    worker::ShortestPathWorker,
+    DhtTable as _, Mapper, Meta, ShortestPathJob, ShortestPathTables,
+};
 use crate::{
     ampc::{
         dht::{U64Min, UpsertAction},
         DhtConn,
     },
-    webgraph,
+    webgraph::{self, query},
     webpage::html::links::RelFlags,
 };
 
@@ -84,7 +87,7 @@ impl ShortestPathMapper {
 
     fn map_batch(
         batch: &[webgraph::SmallEdge],
-        new_changed_nodes: &Mutex<U64BloomFilter>,
+        new_changed_nodes: &Mutex<UpdatedNodes>,
         round_had_changes: &AtomicBool,
         dht: &DhtConn<ShortestPathTables>,
     ) {
@@ -93,11 +96,94 @@ impl ShortestPathMapper {
 
         for (node, action) in updates {
             if action.is_changed() {
-                new_changed_nodes.insert_u128(node.as_u128());
+                new_changed_nodes.add(node);
                 round_had_changes.store(true, std::sync::atomic::Ordering::Relaxed);
             }
         }
     }
+
+    fn relax_all_edges(
+        worker: &ShortestPathWorker,
+        changed_nodes: &UpdatedNodes,
+        new_changed_nodes: &Mutex<UpdatedNodes>,
+        round_had_changes: &AtomicBool,
+        dht: &DhtConn<ShortestPathTables>,
+    ) {
+        let pool = rayon::ThreadPoolBuilder::new().build().unwrap();
+        pool.scope(|s| {
+            let mut batch = Vec::with_capacity(BATCH_SIZE);
+
+            for edge in worker.graph().page_edges() {
+                if edge.rel_flags.intersects(*SKIPPED_REL) {
+                    continue;
+                }
+
+                if changed_nodes.contains(edge.from) {
+                    batch.push(edge);
+                }
+
+                if batch.len() >= BATCH_SIZE {
+                    let update_batch = batch.clone();
+                    s.spawn(move |_| {
+                        Self::map_batch(&update_batch, new_changed_nodes, round_had_changes, dht)
+                    });
+                    batch.clear();
+                }
+            }
+
+            if !batch.is_empty() {
+                Self::map_batch(&batch, new_changed_nodes, round_had_changes, dht);
+            }
+        });
+    }
+
+    fn relax_exact_edges(
+        worker: &ShortestPathWorker,
+        changed_nodes: &UpdatedNodes,
+        exact_changed_nodes: &[webgraph::NodeID],
+        new_changed_nodes: &Mutex<UpdatedNodes>,
+        round_had_changes: &AtomicBool,
+        dht: &DhtConn<ShortestPathTables>,
+    ) {
+        let mut batch = Vec::with_capacity(BATCH_SIZE);
+
+        let pool = rayon::ThreadPoolBuilder::new().build().unwrap();
+
+        pool.scope(|s| {
+            for node in exact_changed_nodes {
+                for edge in worker
+                    .graph()
+                    .search(&query::ForwardlinksQuery::new(*node))
+                    .unwrap_or_default()
+                {
+                    if edge.rel_flags.intersects(*SKIPPED_REL) {
+                        continue;
+                    }
+
+                    if changed_nodes.contains(edge.from) {
+                        batch.push(edge);
+                    }
+
+                    if batch.len() >= BATCH_SIZE {
+                        let update_batch = batch.clone();
+                        s.spawn(move |_| {
+                            Self::map_batch(
+                                &update_batch,
+                                new_changed_nodes,
+                                round_had_changes,
+                                dht,
+                            )
+                        });
+                        batch.clear();
+                    }
+                }
+            }
+        });
+
+        if !batch.is_empty() {
+            Self::map_batch(&batch, new_changed_nodes, round_had_changes, dht);
+        }
+    }
 }
 
 impl Mapper for ShortestPathMapper {
@@ -112,47 +198,41 @@ impl Mapper for ShortestPathMapper {
         match self {
             ShortestPathMapper::RelaxEdges => {
                 let round_had_changes = Arc::new(AtomicBool::new(false));
-                let pool = rayon::ThreadPoolBuilder::new().build().unwrap();
-
-                let new_changed_nodes = Arc::new(Mutex::new(U64BloomFilter::empty_from(
-                    &worker.changed_nodes().lock().unwrap(),
-                )));
-
-                pool.scope(|s| {
-                    let mut changed_nodes = worker.changed_nodes().lock().unwrap();
-                    changed_nodes.insert_u128(job.source.as_u128());
-
-                    let mut batch = Vec::with_capacity(BATCH_SIZE);
-
-                    for edge in worker.graph().page_edges() {
-                        if edge.rel_flags.intersects(*SKIPPED_REL) {
-                            continue;
-                        }
-
-                        if changed_nodes.contains_u128(edge.from.as_u128()) {
-                            batch.push(edge);
-                        }
-
-                        if batch.len() >= BATCH_SIZE {
-                            let update_batch = batch.clone();
-                            let update_new_changed_nodes = new_changed_nodes.clone();
-                            let update_round_had_changes = round_had_changes.clone();
-                            s.spawn(move |_| {
-                                Self::map_batch(
-                                    &update_batch,
-                                    &update_new_changed_nodes,
-                                    &update_round_had_changes,
-                                    dht,
-                                )
-                            });
-                            batch.clear();
-                        }
-                    }
 
-                    if !batch.is_empty() {
-                        Self::map_batch(&batch, &new_changed_nodes, &round_had_changes, dht);
+                let mut changed_nodes = worker.changed_nodes().lock().unwrap();
+                changed_nodes.add(job.source);
+
+                let new_changed_nodes =
+                    Arc::new(Mutex::new(UpdatedNodes::empty_from(&changed_nodes)));
+
+                match changed_nodes.kind() {
+                    UpdatedNodesKind::Exact => {
+                        let exact_changed_nodes: Vec<_> = changed_nodes
+                            .as_exact()
+                            .unwrap()
+                            .clone()
+                            .into_iter()
+                            .collect();
+
+                        Self::relax_exact_edges(
+                            worker,
+                            &changed_nodes,
+                            &exact_changed_nodes,
+                            &new_changed_nodes,
+                            &round_had_changes,
+                            dht,
+                        );
                     }
-                });
+                    UpdatedNodesKind::Sketch => {
+                        Self::relax_all_edges(
+                            worker,
+                            &changed_nodes,
+                            &new_changed_nodes,
+                            &round_had_changes,
+                            dht,
+                        );
+                    }
+                }
 
                 dht.next()
                     .changed_nodes
@@ -169,10 +249,10 @@ impl Mapper for ShortestPathMapper {
                 let all_changed_nodes: Vec<_> =
                     dht.next().changed_nodes.iter().map(|(_, v)| v).collect();
                 let mut changed_nodes =
-                    U64BloomFilter::empty_from(&worker.changed_nodes().lock().unwrap());
+                    UpdatedNodes::empty_from(&worker.changed_nodes().lock().unwrap());
 
-                for bloom in all_changed_nodes {
-                    changed_nodes.union(bloom.clone());
+                for other in &all_changed_nodes {
+                    changed_nodes = changed_nodes.union(other);
                 }
 
                 *worker.changed_nodes().lock().unwrap() = changed_nodes;

+ 3 - 2
crates/core/src/entrypoint/ampc/shortest_path/mod.rs

@@ -18,9 +18,10 @@
 
 pub mod coordinator;
 mod mapper;
+mod updated_nodes;
 pub mod worker;
 
-use bloom::U64BloomFilter;
+pub use updated_nodes::UpdatedNodes;
 
 use crate::distributed::member::ShardId;
 use crate::{
@@ -49,7 +50,7 @@ pub struct Meta {
 pub struct ShortestPathTables {
     distances: DefaultDhtTable<webgraph::NodeID, u64>,
     meta: DefaultDhtTable<(), Meta>,
-    changed_nodes: DefaultDhtTable<ShardId, U64BloomFilter>,
+    changed_nodes: DefaultDhtTable<ShardId, UpdatedNodes>,
 }
 
 impl_dht_tables!(ShortestPathTables, [distances, meta, changed_nodes]);

+ 165 - 0
crates/core/src/entrypoint/ampc/shortest_path/updated_nodes.rs

@@ -0,0 +1,165 @@
+// Stract is an open source web search engine.
+// Copyright (C) 2024 Stract ApS
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program.  If not, see <https://www.gnu.org/licenses/>
+
+use bloom::U64BloomFilter;
+use rustc_hash::FxHashSet;
+
+use crate::webgraph;
+
+const SKETCH_THRESHOLD: usize = 16_384;
+
+#[derive(
+    Debug, Clone, bincode::Encode, bincode::Decode, serde::Serialize, serde::Deserialize, PartialEq,
+)]
+pub enum InnerUpdatedNodes {
+    Exact(FxHashSet<webgraph::NodeID>),
+    Sketch(U64BloomFilter),
+}
+
+impl Default for InnerUpdatedNodes {
+    fn default() -> Self {
+        InnerUpdatedNodes::Exact(FxHashSet::default())
+    }
+}
+
+impl InnerUpdatedNodes {
+    fn contains(&self, node: webgraph::NodeID) -> bool {
+        match self {
+            InnerUpdatedNodes::Exact(nodes) => nodes.contains(&node),
+            InnerUpdatedNodes::Sketch(sketch) => sketch.contains_u128(node.as_u128()),
+        }
+    }
+
+    fn union(&self, other: &Self, total_nodes: u64) -> Self {
+        match (self, other) {
+            (InnerUpdatedNodes::Exact(nodes), InnerUpdatedNodes::Exact(other_nodes)) => {
+                let mut new_nodes = nodes.clone();
+                new_nodes.extend(other_nodes);
+                if new_nodes.len() > SKETCH_THRESHOLD {
+                    let mut bloom = U64BloomFilter::new(total_nodes, 0.01);
+
+                    for node in nodes {
+                        bloom.insert_u128(node.as_u128());
+                    }
+
+                    InnerUpdatedNodes::Sketch(bloom)
+                } else {
+                    InnerUpdatedNodes::Exact(new_nodes)
+                }
+            }
+            (InnerUpdatedNodes::Sketch(sketch), InnerUpdatedNodes::Sketch(other_sketch)) => {
+                let mut new_sketch = sketch.clone();
+                new_sketch.union(other_sketch.clone());
+                InnerUpdatedNodes::Sketch(new_sketch)
+            }
+            (InnerUpdatedNodes::Exact(nodes), InnerUpdatedNodes::Sketch(sketch)) => {
+                let mut new_sketch = sketch.clone();
+
+                for node in nodes {
+                    new_sketch.insert_u128(node.as_u128());
+                }
+
+                InnerUpdatedNodes::Sketch(new_sketch)
+            }
+            (InnerUpdatedNodes::Sketch(sketch), InnerUpdatedNodes::Exact(nodes)) => {
+                let mut new_sketch = sketch.clone();
+
+                for node in nodes {
+                    new_sketch.insert_u128(node.as_u128());
+                }
+
+                InnerUpdatedNodes::Sketch(new_sketch)
+            }
+        }
+    }
+
+    fn add(&mut self, node: webgraph::NodeID, total_nodes: u64) {
+        match self {
+            InnerUpdatedNodes::Exact(nodes) => {
+                nodes.insert(node);
+
+                if nodes.len() > SKETCH_THRESHOLD {
+                    let mut bloom = U64BloomFilter::new(total_nodes, 0.01);
+
+                    for node in nodes.iter() {
+                        bloom.insert_u128(node.as_u128());
+                    }
+
+                    *self = InnerUpdatedNodes::Sketch(bloom);
+                }
+            }
+            InnerUpdatedNodes::Sketch(sketch) => sketch.insert_u128(node.as_u128()),
+        }
+    }
+}
+
+pub enum UpdatedNodesKind {
+    Exact,
+    Sketch,
+}
+
+#[derive(
+    Debug, Clone, bincode::Encode, bincode::Decode, serde::Serialize, serde::Deserialize, PartialEq,
+)]
+pub struct UpdatedNodes {
+    inner: InnerUpdatedNodes,
+    total_nodes: u64,
+}
+
+impl UpdatedNodes {
+    pub fn new(total_nodes: u64) -> Self {
+        Self {
+            inner: InnerUpdatedNodes::default(),
+            total_nodes,
+        }
+    }
+
+    pub fn kind(&self) -> UpdatedNodesKind {
+        match self.inner {
+            InnerUpdatedNodes::Exact(_) => UpdatedNodesKind::Exact,
+            InnerUpdatedNodes::Sketch(_) => UpdatedNodesKind::Sketch,
+        }
+    }
+
+    pub fn add(&mut self, node: webgraph::NodeID) {
+        self.inner.add(node, self.total_nodes);
+    }
+
+    pub fn union(&self, other: &Self) -> Self {
+        Self {
+            inner: self.inner.union(&other.inner, self.total_nodes),
+            total_nodes: self.total_nodes,
+        }
+    }
+
+    pub fn contains(&self, node: webgraph::NodeID) -> bool {
+        self.inner.contains(node)
+    }
+
+    pub fn empty_from(other: &Self) -> Self {
+        Self {
+            inner: InnerUpdatedNodes::default(),
+            total_nodes: other.total_nodes,
+        }
+    }
+
+    pub fn as_exact(&self) -> Option<&FxHashSet<webgraph::NodeID>> {
+        match &self.inner {
+            InnerUpdatedNodes::Exact(nodes) => Some(nodes),
+            InnerUpdatedNodes::Sketch(_) => None,
+        }
+    }
+}

+ 6 - 9
crates/core/src/entrypoint/ampc/shortest_path/worker.rs

@@ -30,14 +30,13 @@ use std::{
     sync::{Arc, Mutex},
 };
 
-use super::{impl_worker, Message, RemoteWorker, ShortestPathJob};
-use bloom::U64BloomFilter;
+use super::{impl_worker, updated_nodes::UpdatedNodes, Message, RemoteWorker, ShortestPathJob};
 
 #[derive(Clone)]
 pub struct ShortestPathWorker {
     shard: ShardId,
     graph: Arc<Webgraph>,
-    changed_nodes: Arc<Mutex<U64BloomFilter>>,
+    changed_nodes: Arc<Mutex<UpdatedNodes>>,
     nodes_sketch: HyperLogLog<4096>,
 }
 
@@ -48,14 +47,12 @@ impl ShortestPathWorker {
         for node in graph.page_nodes() {
             nodes_sketch.add_u128(node.as_u128());
         }
+        let num_nodes = nodes_sketch.size() as u64;
 
         Self {
             graph: Arc::new(graph),
             shard,
-            changed_nodes: Arc::new(Mutex::new(U64BloomFilter::new(
-                nodes_sketch.size() as u64,
-                0.01,
-            ))),
+            changed_nodes: Arc::new(Mutex::new(UpdatedNodes::new(num_nodes))),
             nodes_sketch,
         }
     }
@@ -68,7 +65,7 @@ impl ShortestPathWorker {
         self.shard
     }
 
-    pub fn changed_nodes(&self) -> &Arc<Mutex<U64BloomFilter>> {
+    pub fn changed_nodes(&self) -> &Arc<Mutex<UpdatedNodes>> {
         &self.changed_nodes
     }
 
@@ -78,7 +75,7 @@ impl ShortestPathWorker {
 
     pub fn update_changed_nodes_precision(&self, num_nodes: u64) {
         let mut changed_nodes = self.changed_nodes().lock().unwrap();
-        *changed_nodes = U64BloomFilter::new(num_nodes, 0.01);
+        *changed_nodes = UpdatedNodes::new(num_nodes);
     }
 }