Prechádzať zdrojové kódy

Implement first WebSocket/OT server

Eric Zhang 4 rokov pred
rodič
commit
960bdb315a

+ 7 - 0
Cargo.lock

@@ -9,6 +9,12 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "anyhow"
+version = "1.0.40"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "28b2cd92db5cbd74e8e5028f7e27dd7aa3090e89e4f2a197cc7c8dfb69c7063b"
+
 [[package]]
 name = "atty"
 version = "0.2.14"
@@ -836,6 +842,7 @@ dependencies = [
 name = "rustpad-server"
 version = "0.1.0"
 dependencies = [
+ "anyhow",
  "dotenv",
  "futures",
  "log",

+ 1 - 0
rustpad-server/Cargo.toml

@@ -5,6 +5,7 @@ authors = ["Eric Zhang <ekzhang1@gmail.com>"]
 edition = "2018"
 
 [dependencies]
+anyhow = "1.0.40"
 dotenv = "0.15.0"
 futures = "0.3.15"
 log = "0.4.14"

+ 76 - 8
rustpad-server/src/lib.rs

@@ -30,29 +30,97 @@ fn backend() -> BoxedFilter<(impl Reply,)> {
     let socket = warp::path("socket")
         .and(warp::path::end())
         .and(warp::ws())
-        .and(rustpad)
+        .and(rustpad.clone())
         .map(|ws: Ws, rustpad: Arc<Rustpad>| {
             ws.on_upgrade(move |socket| async move { rustpad.on_connection(socket).await })
         });
 
-    socket.boxed()
+    let text = warp::path("text")
+        .and(warp::path::end())
+        .and(rustpad.clone())
+        .map(|rustpad: Arc<Rustpad>| rustpad.text());
+
+    socket.or(text).boxed()
 }
 
 #[cfg(test)]
 mod tests {
+    use log::info;
+    use operational_transform::OperationSeq;
+    use serde_json::{json, Value};
+
     use super::*;
 
     #[tokio::test]
-    async fn test_single_message() {
+    async fn test_single_operation() {
+        pretty_env_logger::try_init().ok();
         let filter = backend();
+
+        let resp = warp::test::request().path("/text").reply(&filter).await;
+        assert_eq!(resp.status(), 200);
+        assert_eq!(resp.body(), "");
+
         let mut client = warp::test::ws()
             .path("/socket")
-            .handshake(filter)
+            .handshake(filter.clone())
             .await
             .expect("handshake");
-        client.send_text("hello world").await;
-        let msg = client.recv().await.expect("recv");
-        let msg = msg.to_str().expect("string");
-        assert_eq!(msg, "[[0,\"hello world\"]]");
+        let msg = client.recv().await.unwrap();
+        let msg = msg.to_str().unwrap();
+        assert_eq!(msg, r#"{"Identity":0}"#);
+
+        let mut operation = OperationSeq::default();
+        operation.insert("hello");
+        let serialized = format!(
+            r#"{{"Edit": {{"revision": 0, "operation": {}}}}}"#,
+            serde_json::to_string(&operation).unwrap(),
+        );
+        info!("sending ClientMsg {}", serialized);
+        client.send_text(serialized).await;
+
+        let msg = client.recv().await.unwrap();
+        let msg = msg.to_str().unwrap();
+        let msg: Value = serde_json::from_str(&msg).unwrap();
+        assert_eq!(
+            msg,
+            json!({
+                "History": {
+                    "start": 0,
+                    "operations": [
+                        { "id": 0, "operation": ["hello"] }
+                    ]
+                }
+            })
+        );
+
+        let resp = warp::test::request().path("/text").reply(&filter).await;
+        assert_eq!(resp.status(), 200);
+        assert_eq!(resp.body(), "hello");
+    }
+
+    #[tokio::test]
+    async fn test_invalid_operation() {
+        pretty_env_logger::try_init().ok();
+        let filter = backend();
+
+        let mut client = warp::test::ws()
+            .path("/socket")
+            .handshake(filter.clone())
+            .await
+            .expect("handshake");
+        let msg = client.recv().await.unwrap();
+        let msg = msg.to_str().unwrap();
+        assert_eq!(msg, r#"{"Identity":0}"#);
+
+        let mut operation = OperationSeq::default();
+        operation.insert("hello");
+        let serialized = format!(
+            r#"{{"Edit": {{"revision": 1, "operation": {}}}}}"#,
+            serde_json::to_string(&operation).unwrap(),
+        );
+        info!("sending ClientMsg {}", serialized);
+        client.send_text(serialized).await;
+
+        client.recv_closed().await.expect("socket should be closed");
     }
 }

+ 94 - 48
rustpad-server/src/rustpad.rs

@@ -3,10 +3,11 @@
 use std::sync::atomic::{AtomicU64, Ordering};
 use std::time::Duration;
 
+use anyhow::{bail, Context, Result};
 use futures::prelude::*;
-use log::{error, info};
+use log::{info, warn};
 use operational_transform::OperationSeq;
-use parking_lot::RwLock;
+use parking_lot::{RwLock, RwLockUpgradableReadGuard};
 use serde::{Deserialize, Serialize};
 use tokio::{sync::Notify, time};
 use warp::ws::{Message, WebSocket};
@@ -22,24 +23,45 @@ pub struct Rustpad {
 /// Shared state involving multiple users, protected by a lock
 #[derive(Default)]
 struct State {
-    messages: Vec<(u64, String)>,
+    operations: Vec<UserOperation>,
+    text: String,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+struct UserOperation {
+    id: u64,
+    operation: OperationSeq,
 }
 
 /// A message received from the client over WebSocket
 #[derive(Clone, Debug, Serialize, Deserialize)]
 enum ClientMsg {
-    Edit { revision: usize, op: OperationSeq },
+    /// Represents a sequence of local edits from the user
+    Edit {
+        revision: usize,
+        operation: OperationSeq,
+    },
 }
 
 /// A message sent to the client over WebSocket
 #[derive(Clone, Debug, Serialize, Deserialize)]
 enum ServerMsg {
+    /// Informs the client of their unique socket ID
+    Identity(u64),
+    /// Broadcasts text operations to all clients
     History {
-        revision: usize,
-        ops: Vec<OperationSeq>,
+        start: usize,
+        operations: Vec<UserOperation>,
     },
 }
 
+impl From<ServerMsg> for Message {
+    fn from(msg: ServerMsg) -> Self {
+        let serialized = serde_json::to_string(&msg).expect("failed serialize");
+        Message::text(serialized)
+    }
+}
+
 impl Rustpad {
     /// Construct a new, empty Rustpad object
     pub fn new() -> Self {
@@ -47,21 +69,35 @@ impl Rustpad {
     }
 
     /// Handle a connection from a WebSocket
-    pub async fn on_connection(&self, mut socket: WebSocket) {
+    pub async fn on_connection(&self, socket: WebSocket) {
         let id = self.count.fetch_add(1, Ordering::Relaxed);
         info!("connection! id = {}", id);
+        if let Err(e) = self.handle_connection(id, socket).await {
+            warn!("connection terminated early: {}", e);
+        }
+        info!("disconnection, id = {}", id);
+    }
+
+    /// Returns a snapshot of the latest text
+    pub fn text(&self) -> String {
+        let state = self.state.read();
+        state.text.clone()
+    }
+
+    /// Returns the current revision
+    pub fn revision(&self) -> usize {
+        let state = self.state.read();
+        state.operations.len()
+    }
+
+    async fn handle_connection(&self, id: u64, mut socket: WebSocket) -> Result<()> {
+        socket.send(ServerMsg::Identity(id).into()).await?;
 
         let mut revision: usize = 0;
 
         loop {
-            if self.num_messages() > revision {
-                match self.send_messages(revision, &mut socket).await {
-                    Ok(new_revision) => revision = new_revision,
-                    Err(e) => {
-                        error!("websocket error: {}", e);
-                        break;
-                    }
-                }
+            if self.revision() > revision {
+                revision = self.send_history(revision, &mut socket).await?
             }
 
             let sleep = time::sleep(Duration::from_millis(500));
@@ -72,56 +108,66 @@ impl Rustpad {
                 result = socket.next() => {
                     match result {
                         None => break,
-                        Some(Ok(message)) => {
-                            self.handle_message(id, message).await
-                        }
-                        Some(Err(e)) => {
-                            error!("websocket error: {}", e);
-                            break;
+                        Some(message) => {
+                            self.handle_message(id, message?).await?;
                         }
                     }
                 }
             }
         }
 
-        info!("disconnection, id = {}", id);
-    }
-
-    fn num_messages(&self) -> usize {
-        let state = self.state.read();
-        state.messages.len()
+        Ok(())
     }
 
-    async fn send_messages(
-        &self,
-        revision: usize,
-        socket: &mut WebSocket,
-    ) -> Result<usize, warp::Error> {
-        let messages = {
+    async fn send_history(&self, start: usize, socket: &mut WebSocket) -> Result<usize> {
+        let operations = {
             let state = self.state.read();
-            let len = state.messages.len();
-            if revision < len {
-                state.messages[revision..].to_owned()
+            let len = state.operations.len();
+            if start < len {
+                state.operations[start..].to_owned()
             } else {
                 Vec::new()
             }
         };
-        if !messages.is_empty() {
-            let serialized = serde_json::to_string(&messages)
-                .expect("serde serialization failed for messages vec");
-            socket.send(Message::text(&serialized)).await?;
+        let num_ops = operations.len();
+        if num_ops > 0 {
+            let msg = ServerMsg::History { start, operations };
+            socket.send(msg.into()).await?;
         }
-        Ok(revision + messages.len())
+        Ok(start + num_ops)
     }
 
-    async fn handle_message(&self, id: u64, message: Message) {
-        let text = match message.to_str() {
-            Ok(text) => String::from(text),
-            Err(()) => return, // Ignore non-text messages
+    async fn handle_message(&self, id: u64, message: Message) -> Result<()> {
+        let msg: ClientMsg = match message.to_str() {
+            Ok(text) => serde_json::from_str(text).context("failed to deserialize message")?,
+            Err(()) => return Ok(()), // Ignore non-text messages
         };
+        match msg {
+            ClientMsg::Edit {
+                revision,
+                operation,
+            } => {
+                self.apply_edit(id, revision, operation)
+                    .context("invalid edit operation")?;
+                self.notify.notify_waiters();
+            }
+        }
+        Ok(())
+    }
 
-        let mut state = self.state.write();
-        state.messages.push((id, text));
-        self.notify.notify_waiters();
+    fn apply_edit(&self, id: u64, revision: usize, mut operation: OperationSeq) -> Result<()> {
+        let state = self.state.upgradable_read();
+        let len = state.operations.len();
+        if revision > len {
+            bail!("got revision {}, but current is {}", revision, len);
+        }
+        for history_op in &state.operations[revision..] {
+            operation = operation.transform(&history_op.operation)?.0;
+        }
+        let new_text = operation.apply(&state.text)?;
+        let mut state = RwLockUpgradableReadGuard::upgrade(state);
+        state.operations.push(UserOperation { id, operation });
+        state.text = new_text;
+        Ok(())
     }
 }