瀏覽代碼

SharedBuffer: Fix a denial of service

It's a very bad idea to increment the refcount on behalf of another
process. That process may (for either benign or evil reasons) not
reference the SharedBuffer, and then we'll be stuck with loads of
SharedBuffers until we OOM.

Instead, increment the refcount when the buffer is mapped. That way, a
buffer is only kept if *someone* has explicitly requested it via
get_shared_buffer.

Fixes #341
Robin Burchell 6 年之前
父節點
當前提交
2d4d465206
共有 3 個文件被更改,包括 13 次插入10 次删除
  1. 3 3
      Kernel/Process.cpp
  2. 6 4
      Kernel/SharedBuffer.cpp
  3. 4 3
      Kernel/SharedBuffer.h

+ 3 - 3
Kernel/Process.cpp

@@ -2408,7 +2408,7 @@ int Process::sys$create_shared_buffer(int size, void** buffer)
     int shared_buffer_id = ++s_next_shared_buffer_id;
     int shared_buffer_id = ++s_next_shared_buffer_id;
     auto shared_buffer = make<SharedBuffer>(shared_buffer_id, size);
     auto shared_buffer = make<SharedBuffer>(shared_buffer_id, size);
     shared_buffer->share_with(m_pid);
     shared_buffer->share_with(m_pid);
-    *buffer = shared_buffer->get_address(*this);
+    *buffer = shared_buffer->ref_for_process_and_get_address(*this);
     ASSERT((int)shared_buffer->size() >= size);
     ASSERT((int)shared_buffer->size() >= size);
 #ifdef SHARED_BUFFER_DEBUG
 #ifdef SHARED_BUFFER_DEBUG
     kprintf("%s(%u): Created shared buffer %d @ %p (%u bytes, vmo is %u)\n", name().characters(), pid(), shared_buffer_id, *buffer, size, shared_buffer->size());
     kprintf("%s(%u): Created shared buffer %d @ %p (%u bytes, vmo is %u)\n", name().characters(), pid(), shared_buffer_id, *buffer, size, shared_buffer->size());
@@ -2447,7 +2447,7 @@ int Process::sys$release_shared_buffer(int shared_buffer_id)
 #ifdef SHARED_BUFFER_DEBUG
 #ifdef SHARED_BUFFER_DEBUG
     kprintf("%s(%u): Releasing shared buffer %d, buffer count: %u\n", name().characters(), pid(), shared_buffer_id, shared_buffers().resource().size());
     kprintf("%s(%u): Releasing shared buffer %d, buffer count: %u\n", name().characters(), pid(), shared_buffer_id, shared_buffers().resource().size());
 #endif
 #endif
-    shared_buffer.release(*this);
+    shared_buffer.deref_for_process(*this);
     return 0;
     return 0;
 }
 }
 
 
@@ -2463,7 +2463,7 @@ void* Process::sys$get_shared_buffer(int shared_buffer_id)
 #ifdef SHARED_BUFFER_DEBUG
 #ifdef SHARED_BUFFER_DEBUG
     kprintf("%s(%u): Retaining shared buffer %d, buffer count: %u\n", name().characters(), pid(), shared_buffer_id, shared_buffers().resource().size());
     kprintf("%s(%u): Retaining shared buffer %d, buffer count: %u\n", name().characters(), pid(), shared_buffer_id, shared_buffers().resource().size());
 #endif
 #endif
-    return shared_buffer.get_address(*this);
+    return shared_buffer.ref_for_process_and_get_address(*this);
 }
 }
 
 
 int Process::sys$seal_shared_buffer(int shared_buffer_id)
 int Process::sys$seal_shared_buffer(int shared_buffer_id)

+ 6 - 4
Kernel/SharedBuffer.cpp

@@ -21,12 +21,14 @@ bool SharedBuffer::is_shared_with(pid_t peer_pid)
     return false;
     return false;
 }
 }
 
 
-void* SharedBuffer::get_address(Process& process)
+void* SharedBuffer::ref_for_process_and_get_address(Process& process)
 {
 {
     LOCKER(shared_buffers().lock());
     LOCKER(shared_buffers().lock());
     ASSERT(is_shared_with(process.pid()));
     ASSERT(is_shared_with(process.pid()));
     for (auto& ref : m_refs) {
     for (auto& ref : m_refs) {
         if (ref.pid == process.pid()) {
         if (ref.pid == process.pid()) {
+            ref.count++;
+            m_total_refs++;
             if (ref.region == nullptr) {
             if (ref.region == nullptr) {
                 ref.region = process.allocate_region_with_vmo(VirtualAddress(), size(), m_vmo, 0, "SharedBuffer", PROT_READ | (m_writable ? PROT_WRITE : 0));
                 ref.region = process.allocate_region_with_vmo(VirtualAddress(), size(), m_vmo, 0, "SharedBuffer", PROT_READ | (m_writable ? PROT_WRITE : 0));
                 ref.region->set_shared(true);
                 ref.region->set_shared(true);
@@ -42,7 +44,7 @@ void SharedBuffer::share_with(pid_t peer_pid)
     LOCKER(shared_buffers().lock());
     LOCKER(shared_buffers().lock());
     for (auto& ref : m_refs) {
     for (auto& ref : m_refs) {
         if (ref.pid == peer_pid) {
         if (ref.pid == peer_pid) {
-            ref.count++;
+            // don't increment the reference count yet; let them get_shared_buffer it first.
             return;
             return;
         }
         }
     }
     }
@@ -50,7 +52,7 @@ void SharedBuffer::share_with(pid_t peer_pid)
     m_refs.append(Reference(peer_pid));
     m_refs.append(Reference(peer_pid));
 }
 }
 
 
-void SharedBuffer::release(Process& process)
+void SharedBuffer::deref_for_process(Process& process)
 {
 {
     LOCKER(shared_buffers().lock());
     LOCKER(shared_buffers().lock());
     for (int i = 0; i < m_refs.size(); ++i) {
     for (int i = 0; i < m_refs.size(); ++i) {
@@ -94,7 +96,7 @@ void SharedBuffer::disown(pid_t pid)
 void SharedBuffer::destroy_if_unused()
 void SharedBuffer::destroy_if_unused()
 {
 {
     LOCKER(shared_buffers().lock());
     LOCKER(shared_buffers().lock());
-    if (m_refs.size() == 0) {
+    if (m_total_refs == 0) {
 #ifdef SHARED_BUFFER_DEBUG
 #ifdef SHARED_BUFFER_DEBUG
         kprintf("Destroying unused SharedBuffer{%p} id: %d\n", this, m_shared_buffer_id);
         kprintf("Destroying unused SharedBuffer{%p} id: %d\n", this, m_shared_buffer_id);
 #endif
 #endif

+ 4 - 3
Kernel/SharedBuffer.h

@@ -12,7 +12,7 @@ private:
         }
         }
 
 
         pid_t pid;
         pid_t pid;
-        unsigned count { 1 };
+        unsigned count { 0 };
         Region* region { nullptr };
         Region* region { nullptr };
     };
     };
 public:
 public:
@@ -33,9 +33,9 @@ public:
     }
     }
 
 
     bool is_shared_with(pid_t peer_pid);
     bool is_shared_with(pid_t peer_pid);
-    void* get_address(Process& process);
+    void* ref_for_process_and_get_address(Process& process);
     void share_with(pid_t peer_pid);
     void share_with(pid_t peer_pid);
-    void release(Process& process);
+    void deref_for_process(Process& process);
     void disown(pid_t pid);
     void disown(pid_t pid);
     size_t size() const { return m_vmo->size(); }
     size_t size() const { return m_vmo->size(); }
     void destroy_if_unused();
     void destroy_if_unused();
@@ -45,6 +45,7 @@ public:
     bool m_writable { true };
     bool m_writable { true };
     NonnullRefPtr<VMObject> m_vmo;
     NonnullRefPtr<VMObject> m_vmo;
     Vector<Reference, 2> m_refs;
     Vector<Reference, 2> m_refs;
+    unsigned m_total_refs { 0 };
 };
 };
 
 
 Lockable<HashMap<int, OwnPtr<SharedBuffer>>>& shared_buffers();
 Lockable<HashMap<int, OwnPtr<SharedBuffer>>>& shared_buffers();