Browse Source

Kernel: Fix subtle race condition in sys$write implementation

There is a slight race condition in our implementation of write().
We call File::can_write() before attempting to write to it (blocking if
it returns false). If it returns true, we assume that we can write to
the file, and our code assumes that File::write() cannot possibly fail
by being blocked. There is, however, the rare case where another process
writes to the file and prevents further writes in between the call to
Files::can_write() and File::write() in the first process. This would
result in the first process calling File::write() when it cannot be
written to.

We fix this by adding a mechanism for File::can_write() to signal that
it was blocked, making it the responsibilty of File::write() to check
whether it can write and then finally making sys$write() check if the
write failed due to it being blocked.
Sahan Fernando 4 năm trước cách đây
mục cha
commit
d0f314b23c
3 tập tin đã thay đổi với 15 bổ sung14 xóa
  1. 5 3
      Kernel/Devices/SerialDevice.cpp
  2. 1 0
      Kernel/Devices/SerialDevice.h
  3. 9 11
      Kernel/Syscalls/write.cpp

+ 5 - 3
Kernel/Devices/SerialDevice.cpp

@@ -31,6 +31,7 @@ KResultOr<size_t> SerialDevice::read(FileDescription&, u64, UserOrKernelBuffer&
     if (!size)
         return 0;
 
+    ScopedSpinLock lock(m_serial_lock);
     if (!(get_line_status() & DataReady))
         return 0;
 
@@ -46,13 +47,14 @@ bool SerialDevice::can_write(const FileDescription&, size_t) const
     return (get_line_status() & EmptyTransmitterHoldingRegister) != 0;
 }
 
-KResultOr<size_t> SerialDevice::write(FileDescription&, u64, const UserOrKernelBuffer& buffer, size_t size)
+KResultOr<size_t> SerialDevice::write(FileDescription& description, u64, const UserOrKernelBuffer& buffer, size_t size)
 {
     if (!size)
         return 0;
 
-    if (!(get_line_status() & EmptyTransmitterHoldingRegister))
-        return 0;
+    ScopedSpinLock lock(m_serial_lock);
+    if (!can_write(description, size))
+        return EAGAIN;
 
     return buffer.read_buffered<128>(size, [&](u8 const* data, size_t data_size) {
         for (size_t i = 0; i < data_size; i++)

+ 1 - 0
Kernel/Devices/SerialDevice.h

@@ -135,6 +135,7 @@ private:
     bool m_break_enable { false };
     u8 m_modem_control { 0 };
     bool m_last_put_char_was_carriage_return { false };
+    SpinLock<u8> m_serial_lock;
 };
 
 }

+ 9 - 11
Kernel/Syscalls/write.cpp

@@ -60,10 +60,6 @@ KResultOr<ssize_t> Process::sys$writev(int fd, Userspace<const struct iovec*> io
 KResultOr<ssize_t> Process::do_write(FileDescription& description, const UserOrKernelBuffer& data, size_t data_size)
 {
     ssize_t total_nwritten = 0;
-    if (!description.is_blocking()) {
-        if (!description.can_write())
-            return EAGAIN;
-    }
 
     if (description.should_append() && description.file().is_seekable()) {
         auto seek_result = description.seek(0, SEEK_END);
@@ -72,11 +68,12 @@ KResultOr<ssize_t> Process::do_write(FileDescription& description, const UserOrK
     }
 
     while ((size_t)total_nwritten < data_size) {
-        if (!description.can_write()) {
+        while (!description.can_write()) {
             if (!description.is_blocking()) {
-                // Short write: We can no longer write to this non-blocking description.
-                VERIFY(total_nwritten > 0);
-                return total_nwritten;
+                if (total_nwritten > 0)
+                    return total_nwritten;
+                else
+                    return EAGAIN;
             }
             auto unblock_flags = Thread::FileBlocker::BlockFlags::None;
             if (Thread::current()->block<Thread::WriteBlocker>({}, description, unblock_flags).was_interrupted()) {
@@ -87,12 +84,13 @@ KResultOr<ssize_t> Process::do_write(FileDescription& description, const UserOrK
         }
         auto nwritten_or_error = description.write(data.offset(total_nwritten), data_size - total_nwritten);
         if (nwritten_or_error.is_error()) {
-            if (total_nwritten)
+            if (total_nwritten > 0)
                 return total_nwritten;
+            if (nwritten_or_error.error() == EAGAIN)
+                continue;
             return nwritten_or_error.error();
         }
-        if (nwritten_or_error.value() == 0)
-            break;
+        VERIFY(nwritten_or_error.value() > 0);
         total_nwritten += nwritten_or_error.value();
     }
     return total_nwritten;