Переглянути джерело

AK: Disallow returning of string literals for errors in kernel code

This code should not be used in the kernel - we should always propagate
proper errno codes in case we need to return those to userland so it
could decode it in a reasonable way.
Liav A 2 роки тому
батько
коміт
048fb2c204
5 змінених файлів з 51 додано та 20 видалено
  1. 28 5
      AK/Error.h
  2. 1 3
      AK/Format.h
  3. 3 3
      AK/Hex.cpp
  4. 5 5
      AK/MemoryStream.cpp
  5. 14 4
      AK/Stream.cpp

+ 28 - 5
AK/Error.h

@@ -31,7 +31,12 @@ public:
     // For calling this method from userspace programs, we will simply return from
     // the Error::from_string_view method!
     [[nodiscard]] static Error from_string_view_or_print_error_and_return_errno(StringView string_literal, int code);
-    [[nodiscard]] static Error from_syscall(StringView syscall_name, int rc) { return Error(syscall_name, rc); }
+
+#ifndef KERNEL
+    [[nodiscard]] static Error from_syscall(StringView syscall_name, int rc)
+    {
+        return Error(syscall_name, rc);
+    }
     [[nodiscard]] static Error from_string_view(StringView string_literal) { return Error(string_literal); }
 
     [[nodiscard]] static Error copy(Error const& error)
@@ -57,17 +62,29 @@ public:
     {
         return from_string_view(string);
     }
+#endif
 
     bool operator==(Error const& other) const
     {
+#ifdef KERNEL
+        return m_code == other.m_code;
+#else
         return m_code == other.m_code && m_string_literal == other.m_string_literal && m_syscall == other.m_syscall;
+#endif
     }
 
-    bool is_errno() const { return m_code != 0; }
-    bool is_syscall() const { return m_syscall; }
-
     int code() const { return m_code; }
-    StringView string_literal() const { return m_string_literal; }
+#ifndef KERNEL
+    bool is_errno() const
+    {
+        return m_code != 0;
+    }
+    bool is_syscall() const { return m_syscall; }
+    StringView string_literal() const
+    {
+        return m_string_literal;
+    }
+#endif
 
 protected:
     Error(int code)
@@ -76,6 +93,7 @@ protected:
     }
 
 private:
+#ifndef KERNEL
     Error(StringView string_literal)
         : m_string_literal(string_literal)
     {
@@ -92,8 +110,13 @@ private:
     Error& operator=(Error const&) = default;
 
     StringView m_string_literal;
+#endif
+
     int m_code { 0 };
+
+#ifndef KERNEL
     bool m_syscall { false };
+#endif
 };
 
 template<typename T, typename E>

+ 1 - 3
AK/Format.h

@@ -684,9 +684,7 @@ struct Formatter<Error> : Formatter<FormatString> {
     ErrorOr<void> format(FormatBuilder& builder, Error const& error)
     {
 #if defined(AK_OS_SERENITY) && defined(KERNEL)
-        if (error.is_errno())
-            return Formatter<FormatString>::format(builder, "Error(errno={})"sv, error.code());
-        return Formatter<FormatString>::format(builder, "Error({})"sv, error.string_literal());
+        return Formatter<FormatString>::format(builder, "Error(errno={})"sv, error.code());
 #else
         if (error.is_syscall())
             return Formatter<FormatString>::format(builder, "{}: {} (errno={})"sv, error.string_literal(), strerror(error.code()), error.code());

+ 3 - 3
AK/Hex.cpp

@@ -15,18 +15,18 @@ namespace AK {
 ErrorOr<ByteBuffer> decode_hex(StringView input)
 {
     if ((input.length() % 2) != 0)
-        return Error::from_string_literal("Hex string was not an even length");
+        return Error::from_string_view_or_print_error_and_return_errno("Hex string was not an even length"sv, EINVAL);
 
     auto output = TRY(ByteBuffer::create_zeroed(input.length() / 2));
 
     for (size_t i = 0; i < input.length() / 2; ++i) {
         auto const c1 = decode_hex_digit(input[i * 2]);
         if (c1 >= 16)
-            return Error::from_string_literal("Hex string contains invalid digit");
+            return Error::from_string_view_or_print_error_and_return_errno("Hex string contains invalid digit"sv, EINVAL);
 
         auto const c2 = decode_hex_digit(input[i * 2 + 1]);
         if (c2 >= 16)
-            return Error::from_string_literal("Hex string contains invalid digit");
+            return Error::from_string_view_or_print_error_and_return_errno("Hex string contains invalid digit"sv, EINVAL);
 
         output[i] = (c1 << 4) + c2;
     }

+ 5 - 5
AK/MemoryStream.cpp

@@ -59,19 +59,19 @@ ErrorOr<size_t> FixedMemoryStream::seek(i64 offset, SeekMode seek_mode)
     switch (seek_mode) {
     case SeekMode::SetPosition:
         if (offset > static_cast<i64>(m_bytes.size()))
-            return Error::from_string_literal("Offset past the end of the stream memory");
+            return Error::from_string_view_or_print_error_and_return_errno("Offset past the end of the stream memory"sv, EINVAL);
 
         m_offset = offset;
         break;
     case SeekMode::FromCurrentPosition:
         if (offset + static_cast<i64>(m_offset) > static_cast<i64>(m_bytes.size()))
-            return Error::from_string_literal("Offset past the end of the stream memory");
+            return Error::from_string_view_or_print_error_and_return_errno("Offset past the end of the stream memory"sv, EINVAL);
 
         m_offset += offset;
         break;
     case SeekMode::FromEndPosition:
         if (offset > static_cast<i64>(m_bytes.size()))
-            return Error::from_string_literal("Offset past the start of the stream memory");
+            return Error::from_string_view_or_print_error_and_return_errno("Offset past the start of the stream memory"sv, EINVAL);
 
         m_offset = m_bytes.size() - offset;
         break;
@@ -92,7 +92,7 @@ ErrorOr<size_t> FixedMemoryStream::write(ReadonlyBytes bytes)
 ErrorOr<void> FixedMemoryStream::write_entire_buffer(ReadonlyBytes bytes)
 {
     if (remaining() < bytes.size())
-        return Error::from_string_literal("Write of entire buffer ends past the memory area");
+        return Error::from_string_view_or_print_error_and_return_errno("Write of entire buffer ends past the memory area"sv, EINVAL);
 
     TRY(write(bytes));
     return {};
@@ -163,7 +163,7 @@ ErrorOr<void> AllocatingMemoryStream::discard(size_t count)
     VERIFY(m_write_offset >= m_read_offset);
 
     if (count > used_buffer_size())
-        return Error::from_string_literal("Number of discarded bytes is higher than the number of allocated bytes");
+        return Error::from_string_view_or_print_error_and_return_errno("Number of discarded bytes is higher than the number of allocated bytes"sv, EINVAL);
 
     m_read_offset += count;
 

+ 14 - 4
AK/Stream.cpp

@@ -16,14 +16,19 @@ ErrorOr<void> Stream::read_entire_buffer(Bytes buffer)
     size_t nread = 0;
     while (nread < buffer.size()) {
         if (is_eof())
-            return Error::from_string_literal("Reached end-of-file before filling the entire buffer");
+            return Error::from_string_view_or_print_error_and_return_errno("Reached end-of-file before filling the entire buffer"sv, EIO);
 
         auto result = read(buffer.slice(nread));
         if (result.is_error()) {
+#ifdef KERNEL
+            if (result.error().code() == EINTR) {
+                continue;
+            }
+#else
             if (result.error().is_errno() && result.error().code() == EINTR) {
                 continue;
             }
-
+#endif
             return result.release_error();
         }
 
@@ -69,7 +74,7 @@ ErrorOr<void> Stream::discard(size_t discarded_bytes)
 
     while (discarded_bytes > 0) {
         if (is_eof())
-            return Error::from_string_literal("Reached end-of-file before reading all discarded bytes");
+            return Error::from_string_view_or_print_error_and_return_errno("Reached end-of-file before reading all discarded bytes"sv, EIO);
 
         auto slice = TRY(read(buffer.span().slice(0, min(discarded_bytes, continuous_read_size))));
         discarded_bytes -= slice.size();
@@ -84,10 +89,15 @@ ErrorOr<void> Stream::write_entire_buffer(ReadonlyBytes buffer)
     while (nwritten < buffer.size()) {
         auto result = write(buffer.slice(nwritten));
         if (result.is_error()) {
+#ifdef KERNEL
+            if (result.error().code() == EINTR) {
+                continue;
+            }
+#else
             if (result.error().is_errno() && result.error().code() == EINTR) {
                 continue;
             }
-
+#endif
             return result.release_error();
         }