Переглянути джерело

LibCrypto+LibTLS: Avoid unaligned reads and writes

This adds an `AK::ByteReader` to help with that so we don't duplicate
the logic all over the place.
No more `*(const u16*)` and `*(const u32*)` for anyone.
This should help a little with #7060.
Ali Mohammad Pur 4 роки тому
батько
коміт
df515e1d85

+ 69 - 0
AK/ByteReader.h

@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) 2021, Ali Mohammad Pur <mpfard@serenityos.org>
+ *
+ * SPDX-License-Identifier: BSD-2-Clause
+ */
+
+#pragma once
+
+#include <AK/Types.h>
+
+namespace AK {
+
+struct ByteReader {
+    static void store(u8* address, u16 value)
+    {
+        union {
+            u16 _16;
+            u8 _8[2];
+        } const v { ._16 = value };
+        __builtin_memcpy(address, v._8, 2);
+    }
+
+    static void store(u8* address, u32 value)
+    {
+        union {
+            u32 _32;
+            u8 _8[4];
+        } const v { ._32 = value };
+        __builtin_memcpy(address, v._8, 4);
+    }
+
+    static void load(const u8* address, u16& value)
+    {
+        union {
+            u16 _16;
+            u8 _8[2];
+        } v { ._16 = 0 };
+        __builtin_memcpy(&v._8, address, 2);
+        value = v._16;
+    }
+
+    static void load(const u8* address, u32& value)
+    {
+        union {
+            u32 _32;
+            u8 _8[4];
+        } v { ._32 = 0 };
+        __builtin_memcpy(&v._8, address, 4);
+        value = v._32;
+    }
+
+    static u16 load16(const u8* address)
+    {
+        u16 value;
+        load(address, value);
+        return value;
+    }
+
+    static u32 load32(const u8* address)
+    {
+        u32 value;
+        load(address, value);
+        return value;
+    }
+};
+
+}
+
+using AK::ByteReader;

+ 3 - 3
Userland/Libraries/LibCrypto/Authentication/GHash.cpp

@@ -4,6 +4,7 @@
  * SPDX-License-Identifier: BSD-2-Clause
  */
 
+#include <AK/ByteReader.h>
 #include <AK/Debug.h>
 #include <AK/MemoryStream.h>
 #include <AK/Types.h>
@@ -15,14 +16,13 @@ namespace {
 
 static u32 to_u32(const u8* b)
 {
-    return AK::convert_between_host_and_big_endian(*(const u32*)b);
+    return AK::convert_between_host_and_big_endian(ByteReader::load32(b));
 }
 
 static void to_u8s(u8* b, const u32* w)
 {
     for (auto i = 0; i < 4; ++i) {
-        auto& e = *((u32*)(b + i * 4));
-        e = AK::convert_between_host_and_big_endian(w[i]);
+        ByteReader::store(b + i * 4, AK::convert_between_host_and_big_endian(w[i]));
     }
 }
 

+ 8 - 8
Userland/Libraries/LibTLS/ClientHandshake.cpp

@@ -53,7 +53,7 @@ ssize_t TLSv12::handle_hello(ReadonlyBytes buffer, WritePacketStage& write_packe
         dbgln("not enough data for version");
         return (i8)Error::NeedMoreData;
     }
-    auto version = (Version)AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(res));
+    auto version = static_cast<Version>(AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(res))));
 
     res += 2;
     if (!supports_version(version))
@@ -84,7 +84,7 @@ ssize_t TLSv12::handle_hello(ReadonlyBytes buffer, WritePacketStage& write_packe
         dbgln("not enough data for cipher suite listing");
         return (i8)Error::NeedMoreData;
     }
-    auto cipher = (CipherSuite)AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(res));
+    auto cipher = static_cast<CipherSuite>(AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(res))));
     res += 2;
     if (!supports_cipher(cipher)) {
         m_context.cipher = CipherSuite::Invalid;
@@ -113,14 +113,14 @@ ssize_t TLSv12::handle_hello(ReadonlyBytes buffer, WritePacketStage& write_packe
 
     // Presence of extensions is determined by availability of bytes after compression_method
     if (buffer.size() - res >= 2) {
-        auto extensions_bytes_total = AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(res += 2));
+        auto extensions_bytes_total = AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(res += 2)));
         dbgln_if(TLS_DEBUG, "Extensions bytes total: {}", extensions_bytes_total);
     }
 
     while (buffer.size() - res >= 4) {
-        auto extension_type = (HandshakeExtension)AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(res));
+        auto extension_type = (HandshakeExtension)AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(res)));
         res += 2;
-        u16 extension_length = AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(res));
+        u16 extension_length = AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(res)));
         res += 2;
 
         dbgln_if(TLS_DEBUG, "Extension {} with length {}", (u16)extension_type, extension_length);
@@ -134,14 +134,14 @@ ssize_t TLSv12::handle_hello(ReadonlyBytes buffer, WritePacketStage& write_packe
                 // ServerNameList total size
                 if (buffer.size() - res < 2)
                     return (i8)Error::NeedMoreData;
-                auto sni_name_list_bytes = AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(res += 2));
+                auto sni_name_list_bytes = AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(res += 2)));
                 dbgln_if(TLS_DEBUG, "SNI: expecting ServerNameList of {} bytes", sni_name_list_bytes);
 
                 // Exactly one ServerName should be present
                 if (buffer.size() - res < 3)
                     return (i8)Error::NeedMoreData;
                 auto sni_name_type = (NameType)(*(const u8*)buffer.offset_pointer(res++));
-                auto sni_name_length = AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(res += 2));
+                auto sni_name_length = AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(res += 2)));
 
                 if (sni_name_type != NameType::HostName)
                     return (i8)Error::NotUnderstood;
@@ -158,7 +158,7 @@ ssize_t TLSv12::handle_hello(ReadonlyBytes buffer, WritePacketStage& write_packe
             }
         } else if (extension_type == HandshakeExtension::ApplicationLayerProtocolNegotiation && m_context.alpn.size()) {
             if (buffer.size() - res > 2) {
-                auto alpn_length = AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(res));
+                auto alpn_length = AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(res)));
                 if (alpn_length && alpn_length <= extension_length - 2) {
                     const u8* alpn = buffer.offset_pointer(res + 2);
                     size_t alpn_position = 0;

+ 5 - 4
Userland/Libraries/LibTLS/Record.cpp

@@ -34,7 +34,7 @@ void TLSv12::write_packet(ByteBuffer& packet)
 void TLSv12::update_packet(ByteBuffer& packet)
 {
     u32 header_size = 5;
-    *(u16*)packet.offset_pointer(3) = AK::convert_between_host_and_network_endian((u16)(packet.size() - header_size));
+    ByteReader::store(packet.offset_pointer(3), AK::convert_between_host_and_network_endian((u16)(packet.size() - header_size)));
 
     if (packet[0] != (u8)MessageType::ChangeCipher) {
         if (packet[0] == (u8)MessageType::Handshake && packet.size() > header_size) {
@@ -159,7 +159,7 @@ void TLSv12::update_packet(ByteBuffer& packet)
                 // store the correct ciphertext length into the packet
                 u16 ct_length = (u16)ct.size() - header_size;
 
-                *(u16*)ct.offset_pointer(header_size - 2) = AK::convert_between_host_and_network_endian(ct_length);
+                ByteReader::store(ct.offset_pointer(header_size - 2), AK::convert_between_host_and_network_endian(ct_length));
 
                 // replace the packet with the ciphertext
                 packet = ct;
@@ -222,13 +222,14 @@ ssize_t TLSv12::handle_message(ReadonlyBytes buffer)
     // FIXME: Read the version and verify it
 
     if constexpr (TLS_DEBUG) {
-        auto version = (Version) * (const u16*)buffer.offset_pointer(buffer_position);
+        auto version = ByteReader::load16(buffer.offset_pointer(buffer_position));
         dbgln("type={}, version={}", (u8)type, (u16)version);
     }
 
     buffer_position += 2;
 
-    auto length = AK::convert_between_host_and_network_endian(*(const u16*)buffer.offset_pointer(buffer_position));
+    auto length = AK::convert_between_host_and_network_endian(ByteReader::load16(buffer.offset_pointer(buffer_position)));
+
     dbgln_if(TLS_DEBUG, "record length: {} at offset: {}", length, buffer_position);
     buffer_position += 2;
 

+ 2 - 1
Userland/Libraries/LibTLS/TLSPacketBuilder.h

@@ -7,6 +7,7 @@
 #pragma once
 
 #include <AK/ByteBuffer.h>
+#include <AK/ByteReader.h>
 #include <AK/Endian.h>
 #include <AK/Types.h>
 
@@ -38,7 +39,7 @@ public:
         m_packet_data = ByteBuffer::create_uninitialized(size_hint + 16);
         m_current_length = 5;
         m_packet_data[0] = (u8)type;
-        *(u16*)m_packet_data.offset_pointer(1) = AK::convert_between_host_and_network_endian((u16)version);
+        ByteReader::store(m_packet_data.offset_pointer(1), AK::convert_between_host_and_network_endian((u16)version));
     }
 
     inline void append(u16 value)

+ 1 - 1
Userland/Libraries/LibTLS/TLSv12.cpp

@@ -601,7 +601,7 @@ void TLSv12::consume(ReadonlyBytes record)
     dbgln_if(TLS_DEBUG, "message buffer length {}", buffer_length);
 
     while (buffer_length >= 5) {
-        auto length = AK::convert_between_host_and_network_endian(*(u16*)m_context.message_buffer.offset_pointer(index + size_offset)) + header_size;
+        auto length = AK::convert_between_host_and_network_endian(ByteReader::load16(m_context.message_buffer.offset_pointer(index + size_offset))) + header_size;
         if (length > buffer_length) {
             dbgln_if(TLS_DEBUG, "Need more data: {} > {}", length, buffer_length);
             break;