Browse Source

Kernel: Fix a few Thread::block related races

We need to have a Thread lock to protect threading related
operations, such as Thread::m_blocker which is used in
Thread::block.

Also, if a Thread::Blocker indicates that it should be
unblocking immediately, don't actually block the Thread
and instead return immediately in Thread::block.
Tom 5 years ago
parent
commit
c813bb7355

+ 1 - 1
Kernel/FileSystem/Plan9FileSystem.cpp

@@ -498,7 +498,7 @@ KResult Plan9FS::read_and_dispatch_one_message()
     return KSuccess;
 }
 
-bool Plan9FS::Blocker::should_unblock(Thread&, time_t, long)
+bool Plan9FS::Blocker::should_unblock(Thread&)
 {
     if (m_completion.completed)
         return true;

+ 1 - 1
Kernel/FileSystem/Plan9FileSystem.h

@@ -81,7 +81,7 @@ private:
             : m_completion(completion)
         {
         }
-        virtual bool should_unblock(Thread&, time_t, long) override;
+        virtual bool should_unblock(Thread&) override;
         virtual const char* state_string() const override { return "Waiting"; }
 
     private:

+ 5 - 2
Kernel/Ptrace.cpp

@@ -63,8 +63,11 @@ KResultOr<u32> handle_syscall(const Kernel::Syscall::SC_ptrace_params& params, P
             return KResult(-EBUSY);
         }
         peer->start_tracing_from(caller.pid());
-        if (peer->state() != Thread::State::Stopped && !(peer->has_blocker() && peer->blocker().is_reason_signal()))
-            peer->send_signal(SIGSTOP, &caller);
+        if (peer->state() != Thread::State::Stopped) {
+            ScopedSpinLock lock(peer->get_lock());
+            if (!(peer->has_blocker() && peer->blocker().is_reason_signal()))
+                peer->send_signal(SIGSTOP, &caller);
+        }
         return KSuccess;
     }
 

+ 27 - 22
Kernel/Scheduler.cpp

@@ -62,19 +62,6 @@ void Scheduler::init_thread(Thread& thread)
     g_scheduler_data->m_nonrunnable_threads.append(thread);
 }
 
-void Scheduler::update_state_for_thread(Thread& thread)
-{
-    ASSERT_INTERRUPTS_DISABLED();
-    ASSERT(g_scheduler_data);
-    ASSERT(g_scheduler_lock.own_lock());
-    auto& list = g_scheduler_data->thread_list_for_state(thread.state());
-
-    if (list.contains(thread))
-        return;
-
-    list.append(thread);
-}
-
 static u32 time_slice_for(const Thread& thread)
 {
     // One time slice unit == 1ms
@@ -104,7 +91,7 @@ Thread::JoinBlocker::JoinBlocker(Thread& joinee, void*& joinee_exit_value)
     current_thread->m_joinee = &joinee;
 }
 
-bool Thread::JoinBlocker::should_unblock(Thread& joiner, time_t, long)
+bool Thread::JoinBlocker::should_unblock(Thread& joiner)
 {
     return !joiner.m_joinee;
 }
@@ -124,7 +111,7 @@ Thread::AcceptBlocker::AcceptBlocker(const FileDescription& description)
 {
 }
 
-bool Thread::AcceptBlocker::should_unblock(Thread&, time_t, long)
+bool Thread::AcceptBlocker::should_unblock(Thread&)
 {
     auto& socket = *blocked_description().socket();
     return socket.can_accept();
@@ -135,7 +122,7 @@ Thread::ConnectBlocker::ConnectBlocker(const FileDescription& description)
 {
 }
 
-bool Thread::ConnectBlocker::should_unblock(Thread&, time_t, long)
+bool Thread::ConnectBlocker::should_unblock(Thread&)
 {
     auto& socket = *blocked_description().socket();
     return socket.setup_state() == Socket::SetupState::Completed;
@@ -157,12 +144,17 @@ Thread::WriteBlocker::WriteBlocker(const FileDescription& description)
     }
 }
 
-bool Thread::WriteBlocker::should_unblock(Thread&, time_t now_sec, long now_usec)
+bool Thread::WriteBlocker::should_unblock(Thread& thread, time_t now_sec, long now_usec)
 {
     if (m_deadline.has_value()) {
         bool timed_out = now_sec > m_deadline.value().tv_sec || (now_sec == m_deadline.value().tv_sec && now_usec >= m_deadline.value().tv_usec);
         return timed_out || blocked_description().can_write();
     }
+    return should_unblock(thread);
+}
+
+bool Thread::WriteBlocker::should_unblock(Thread&)
+{
     return blocked_description().can_write();
 }
 
@@ -182,12 +174,17 @@ Thread::ReadBlocker::ReadBlocker(const FileDescription& description)
     }
 }
 
-bool Thread::ReadBlocker::should_unblock(Thread&, time_t now_sec, long now_usec)
+bool Thread::ReadBlocker::should_unblock(Thread& thread, time_t now_sec, long now_usec)
 {
     if (m_deadline.has_value()) {
         bool timed_out = now_sec > m_deadline.value().tv_sec || (now_sec == m_deadline.value().tv_sec && now_usec >= m_deadline.value().tv_usec);
         return timed_out || blocked_description().can_read();
     }
+    return should_unblock(thread);
+}
+
+bool Thread::ReadBlocker::should_unblock(Thread&)
+{
     return blocked_description().can_read();
 }
 
@@ -198,7 +195,7 @@ Thread::ConditionBlocker::ConditionBlocker(const char* state_string, Function<bo
     ASSERT(m_block_until_condition);
 }
 
-bool Thread::ConditionBlocker::should_unblock(Thread&, time_t, long)
+bool Thread::ConditionBlocker::should_unblock(Thread&)
 {
     return m_block_until_condition();
 }
@@ -208,7 +205,7 @@ Thread::SleepBlocker::SleepBlocker(u64 wakeup_time)
 {
 }
 
-bool Thread::SleepBlocker::should_unblock(Thread&, time_t, long)
+bool Thread::SleepBlocker::should_unblock(Thread&)
 {
     return m_wakeup_time <= g_uptime;
 }
@@ -228,7 +225,11 @@ bool Thread::SelectBlocker::should_unblock(Thread& thread, time_t now_sec, long
         if (now_sec > m_select_timeout.tv_sec || (now_sec == m_select_timeout.tv_sec && now_usec * 1000 >= m_select_timeout.tv_nsec))
             return true;
     }
+    return should_unblock(thread);
+}
 
+bool Thread::SelectBlocker::should_unblock(Thread& thread)
+{
     auto& process = thread.process();
     for (int fd : m_select_read_fds) {
         if (!process.m_fds[fd])
@@ -252,7 +253,7 @@ Thread::WaitBlocker::WaitBlocker(int wait_options, pid_t& waitee_pid)
 {
 }
 
-bool Thread::WaitBlocker::should_unblock(Thread& thread, time_t, long)
+bool Thread::WaitBlocker::should_unblock(Thread& thread)
 {
     bool should_unblock = m_wait_options & WNOHANG;
     if (m_waitee_pid != -1) {
@@ -294,7 +295,7 @@ Thread::SemiPermanentBlocker::SemiPermanentBlocker(Reason reason)
 {
 }
 
-bool Thread::SemiPermanentBlocker::should_unblock(Thread&, time_t, long)
+bool Thread::SemiPermanentBlocker::should_unblock(Thread&)
 {
     // someone else has to unblock us
     return false;
@@ -304,6 +305,7 @@ bool Thread::SemiPermanentBlocker::should_unblock(Thread&, time_t, long)
 // Make a decision as to whether to unblock them or not.
 void Thread::consider_unblock(time_t now_sec, long now_usec)
 {
+    ScopedSpinLock lock(m_lock);
     switch (state()) {
     case Thread::Invalid:
     case Thread::Runnable:
@@ -403,6 +405,7 @@ bool Scheduler::pick_next()
 
     // Dispatch any pending signals.
     Thread::for_each_living([&](Thread& thread) -> IterationDecision {
+        ScopedSpinLock lock(thread.get_lock());
         if (!thread.has_unmasked_pending_signals())
             return IterationDecision::Continue;
         // NOTE: dispatch_one_pending_signal() may unblock the process.
@@ -427,6 +430,8 @@ bool Scheduler::pick_next()
             dbg() << "  " << String::format("%-12s", thread.state_string()) << " " << thread << " @ " << String::format("%w", thread.tss().cs) << ":" << String::format("%x", thread.tss().eip) << " Reason: " << (thread.wait_reason() ? thread.wait_reason() : "none");
         else if (thread.state() == Thread::Dying)
             dbg() << "  " << String::format("%-12s", thread.state_string()) << " " << thread << " @ " << String::format("%w", thread.tss().cs) << ":" << String::format("%x", thread.tss().eip) << " Finalizable: " << thread.is_finalizable();
+        else
+            dbg() << "  " << String::format("%-12s", thread.state_string()) << " " << thread << " @ " << String::format("%w", thread.tss().cs) << ":" << String::format("%x", thread.tss().eip);
         return IterationDecision::Continue;
     });
 

+ 0 - 1
Kernel/Scheduler.h

@@ -78,7 +78,6 @@ public:
     static inline IterationDecision for_each_nonrunnable(Callback);
 
     static void init_thread(Thread& thread);
-    static void update_state_for_thread(Thread& thread);
 };
 
 }

+ 30 - 4
Kernel/Thread.cpp

@@ -114,6 +114,7 @@ Thread::~Thread()
 
 void Thread::unblock()
 {
+    ASSERT(m_lock.own_lock());
     m_blocker = nullptr;
     if (Thread::current() == this) {
         if (m_should_die)
@@ -144,6 +145,7 @@ void Thread::set_should_die()
     m_should_die = true;
 
     if (is_blocked()) {
+        ScopedSpinLock lock(m_lock);
         ASSERT(m_blocker != nullptr);
         // We're blocked in the kernel.
         m_blocker->set_interrupted_by_death();
@@ -264,6 +266,7 @@ void Thread::finalize()
     set_state(Thread::State::Dead);
 
     if (m_joiner) {
+        ScopedSpinLock lock(m_joiner->m_lock);
         ASSERT(m_joiner->m_joinee == this);
         static_cast<JoinBlocker*>(m_joiner->m_blocker)->set_joinee_exit_value(m_exit_value);
         static_cast<JoinBlocker*>(m_joiner->m_blocker)->set_interrupted_by_death();
@@ -468,9 +471,11 @@ ShouldUnblockThread Thread::dispatch_signal(u8 signal)
         set_state(m_stop_state);
         m_stop_state = State::Invalid;
         // make sure SemiPermanentBlocker is unblocked
-        if (m_state != Thread::Runnable && m_state != Thread::Running
-            && m_blocker && m_blocker->is_reason_signal())
-            unblock();
+        if (m_state != Thread::Runnable && m_state != Thread::Running) {
+            ScopedSpinLock lock(m_lock);
+            if (m_blocker && m_blocker->is_reason_signal())
+                unblock();
+        }
     }
 
     else {
@@ -482,6 +487,7 @@ ShouldUnblockThread Thread::dispatch_signal(u8 signal)
             if (!thread_tracer->has_pending_signal(signal)) {
                 m_stop_signal = signal;
                 // make sure SemiPermanentBlocker is unblocked
+                ScopedSpinLock lock(m_lock);
                 if (m_blocker && m_blocker->is_reason_signal())
                     unblock();
                 set_state(Stopped);
@@ -697,13 +703,15 @@ void Thread::set_state(State new_state)
         m_stop_state = m_state;
     }
 
+    auto previous_state = m_state;
     m_state = new_state;
 #ifdef THREAD_DEBUG
     dbg() << "Set Thread " << *this << " state to " << state_string();
 #endif
 
     if (m_process->pid() != 0) {
-        Scheduler::update_state_for_thread(*this);
+        update_state_for_thread(previous_state);
+        ASSERT(g_scheduler_data->has_thread(*this));
     }
 
     if (m_state == Dying && this != Thread::current() && is_finalizable()) {
@@ -713,6 +721,24 @@ void Thread::set_state(State new_state)
     }
 }
 
+void Thread::update_state_for_thread(Thread::State previous_state)
+{
+    ASSERT_INTERRUPTS_DISABLED();
+    ASSERT(g_scheduler_data);
+    ASSERT(g_scheduler_lock.own_lock());
+    auto& previous_list = g_scheduler_data->thread_list_for_state(previous_state);
+    auto& list = g_scheduler_data->thread_list_for_state(state());
+
+    if (&previous_list != &list) {
+        previous_list.remove(*this);
+    }
+
+    if (list.contains(*this))
+        return;
+
+    list.append(*this);
+}
+
 String Thread::backtrace()
 {
     return backtrace_impl();

+ 45 - 15
Kernel/Thread.h

@@ -126,7 +126,11 @@ public:
     class Blocker {
     public:
         virtual ~Blocker() { }
-        virtual bool should_unblock(Thread&, time_t now_s, long us) = 0;
+        virtual bool should_unblock(Thread& thread, time_t, long)
+        {
+            return should_unblock(thread);
+        }
+        virtual bool should_unblock(Thread&) = 0;
         virtual const char* state_string() const = 0;
         virtual bool is_reason_signal() const { return false; }
         void set_interrupted_by_death() { m_was_interrupted_by_death = true; }
@@ -143,7 +147,7 @@ public:
     class JoinBlocker final : public Blocker {
     public:
         explicit JoinBlocker(Thread& joinee, void*& joinee_exit_value);
-        virtual bool should_unblock(Thread&, time_t now_s, long us) override;
+        virtual bool should_unblock(Thread&) override;
         virtual const char* state_string() const override { return "Joining"; }
         void set_joinee_exit_value(void* value) { m_joinee_exit_value = value; }
 
@@ -166,14 +170,14 @@ public:
     class AcceptBlocker final : public FileDescriptionBlocker {
     public:
         explicit AcceptBlocker(const FileDescription&);
-        virtual bool should_unblock(Thread&, time_t, long) override;
+        virtual bool should_unblock(Thread&) override;
         virtual const char* state_string() const override { return "Accepting"; }
     };
 
     class ConnectBlocker final : public FileDescriptionBlocker {
     public:
         explicit ConnectBlocker(const FileDescription&);
-        virtual bool should_unblock(Thread&, time_t, long) override;
+        virtual bool should_unblock(Thread&) override;
         virtual const char* state_string() const override { return "Connecting"; }
     };
 
@@ -181,6 +185,7 @@ public:
     public:
         explicit WriteBlocker(const FileDescription&);
         virtual bool should_unblock(Thread&, time_t, long) override;
+        virtual bool should_unblock(Thread&) override;
         virtual const char* state_string() const override { return "Writing"; }
 
     private:
@@ -191,6 +196,7 @@ public:
     public:
         explicit ReadBlocker(const FileDescription&);
         virtual bool should_unblock(Thread&, time_t, long) override;
+        virtual bool should_unblock(Thread&) override;
         virtual const char* state_string() const override { return "Reading"; }
 
     private:
@@ -200,7 +206,7 @@ public:
     class ConditionBlocker final : public Blocker {
     public:
         ConditionBlocker(const char* state_string, Function<bool()>&& condition);
-        virtual bool should_unblock(Thread&, time_t, long) override;
+        virtual bool should_unblock(Thread&) override;
         virtual const char* state_string() const override { return m_state_string; }
 
     private:
@@ -211,7 +217,7 @@ public:
     class SleepBlocker final : public Blocker {
     public:
         explicit SleepBlocker(u64 wakeup_time);
-        virtual bool should_unblock(Thread&, time_t, long) override;
+        virtual bool should_unblock(Thread&) override;
         virtual const char* state_string() const override { return "Sleeping"; }
 
     private:
@@ -223,6 +229,7 @@ public:
         typedef Vector<int, FD_SETSIZE> FDVector;
         SelectBlocker(const timespec& ts, bool select_has_timeout, const FDVector& read_fds, const FDVector& write_fds, const FDVector& except_fds);
         virtual bool should_unblock(Thread&, time_t, long) override;
+        virtual bool should_unblock(Thread&) override;
         virtual const char* state_string() const override { return "Selecting"; }
 
     private:
@@ -236,7 +243,7 @@ public:
     class WaitBlocker final : public Blocker {
     public:
         WaitBlocker(int wait_options, pid_t& waitee_pid);
-        virtual bool should_unblock(Thread&, time_t, long) override;
+        virtual bool should_unblock(Thread&) override;
         virtual const char* state_string() const override { return "Waiting"; }
 
     private:
@@ -251,7 +258,7 @@ public:
         };
 
         SemiPermanentBlocker(Reason reason);
-        virtual bool should_unblock(Thread&, time_t, long) override;
+        virtual bool should_unblock(Thread&) override;
         virtual const char* state_string() const override
         {
             switch (m_reason) {
@@ -271,7 +278,11 @@ public:
 
     bool is_stopped() const { return m_state == Stopped; }
     bool is_blocked() const { return m_state == Blocked; }
-    bool has_blocker() const { return m_blocker != nullptr; }
+    bool has_blocker() const
+    {
+        ASSERT(m_lock.own_lock());
+        return m_blocker != nullptr;
+    }
     const Blocker& blocker() const;
 
     u32 cpu() const { return m_cpu.load(AK::MemoryOrder::memory_order_consume); }
@@ -336,17 +347,27 @@ public:
     template<typename T, class... Args>
     [[nodiscard]] BlockResult block(Args&&... args)
     {
-        // We should never be blocking a blocked (or otherwise non-active) thread.
-        ASSERT(state() == Thread::Running);
-        ASSERT(m_blocker == nullptr);
-
         T t(forward<Args>(args)...);
-        m_blocker = &t;
-        set_state(Thread::Blocked);
+
+        {
+            ScopedSpinLock lock(m_lock);
+            // We should never be blocking a blocked (or otherwise non-active) thread.
+            ASSERT(state() == Thread::Running);
+            ASSERT(m_blocker == nullptr);
+
+            if (t.should_unblock(*this)) {
+                // Don't block if the wake condition is already met
+                return BlockResult::NotBlocked;
+            }
+
+            m_blocker = &t;
+            set_state(Thread::Blocked);
+        }
 
         // Yield to the scheduler, and wait for us to resume unblocked.
         yield_without_holding_big_lock();
 
+        ScopedSpinLock lock(m_lock);
         // We should no longer be blocked once we woke up
         ASSERT(state() != Thread::Blocked);
 
@@ -499,6 +520,8 @@ public:
     void stop_tracing();
     void tracer_trap(const RegisterState&);
 
+    RecursiveSpinLock& get_lock() const { return m_lock; }
+
 private:
     IntrusiveListNode m_runnable_list_node;
     IntrusiveListNode m_wait_queue_node;
@@ -511,6 +534,7 @@ private:
     String backtrace_impl();
     void reset_fpu_state();
 
+    mutable RecursiveSpinLock m_lock;
     NonnullRefPtr<Process> m_process;
     int m_tid { -1 };
     TSS32 m_tss;
@@ -567,6 +591,7 @@ private:
     OwnPtr<ThreadTracer> m_tracer;
 
     void yield_without_holding_big_lock();
+    void update_state_for_thread(Thread::State previous_state);
 };
 
 HashTable<Thread*>& thread_table();
@@ -616,6 +641,11 @@ struct SchedulerData {
     ThreadList m_runnable_threads;
     ThreadList m_nonrunnable_threads;
 
+    bool has_thread(Thread& thread) const
+    {
+        return m_runnable_threads.contains(thread) || m_nonrunnable_threads.contains(thread);
+    }
+
     ThreadList& thread_list_for_state(Thread::State state)
     {
         if (Thread::is_runnable_state(state))