Browse Source

ntpquery: Add some more validation of the response header

Nico Weber 4 years ago
parent
commit
0658051996
1 changed files with 24 additions and 3 deletions
  1. 24 3
      Userland/ntpquery.cpp

+ 24 - 3
Userland/ntpquery.cpp

@@ -58,6 +58,10 @@ struct [[gnu::packed]] NtpPacket
     NtpTimestamp origin_timestamp;
     NtpTimestamp receive_timestamp;
     NtpTimestamp transmit_timestamp;
+
+    uint8_t leap_information() const { return li_vn_mode >> 6; }
+    uint8_t version_number() const { return (li_vn_mode >> 3) & 7; }
+    uint8_t mode() const { return li_vn_mode & 7; }
 };
 static_assert(sizeof(NtpPacket) == 48);
 
@@ -218,10 +222,27 @@ int main(int argc, char** argv)
     timeval kernel_receive_time;
     memcpy(&kernel_receive_time, CMSG_DATA(cmsg), sizeof(kernel_receive_time));
 
+    // Checks 3 and 4 from end of section 5 of rfc4330.
+    if (packet.version_number() != 3 && packet.version_number() != 4) {
+        fprintf(stderr, "unexpected version number %d\n", packet.version_number());
+        return 1;
+    }
+    if (packet.mode() != 4) { // 4 means "server", which should be the reply to our 3 ("client") request.
+        fprintf(stderr, "unexpected mode %d\n", packet.mode());
+        return 1;
+    }
+    if (packet.stratum == 0 || packet.stratum >= 16) {
+        fprintf(stderr, "unexpected stratum value %d\n", packet.stratum);
+        return 1;
+    }
     if (packet.origin_timestamp != random_transmit_timestamp) {
         fprintf(stderr, "expected %#016llx as origin timestamp, got %#016llx\n", random_transmit_timestamp, packet.origin_timestamp);
         return 1;
     }
+    if (packet.transmit_timestamp == 0) {
+        fprintf(stderr, "got transmit_timestamp 0\n");
+        return 1;
+    }
 
     NtpTimestamp origin_timestamp = ntp_timestamp_from_timeval(local_transmit_time);
     NtpTimestamp receive_timestamp = be64toh(packet.receive_timestamp);
@@ -242,9 +263,9 @@ int main(int argc, char** argv)
 
     if (verbose) {
         printf("NTP response from %s:\n", inet_ntoa(peer_address.sin_addr));
-        printf("Leap Information: %d\n", packet.li_vn_mode >> 6);
-        printf("Version Number: %d\n", (packet.li_vn_mode >> 3) & 7);
-        printf("Mode: %d\n", packet.li_vn_mode & 7);
+        printf("Leap Information: %d\n", packet.leap_information());
+        printf("Version Number: %d\n", packet.version_number());
+        printf("Mode: %d\n", packet.mode());
         printf("Stratum: %d\n", packet.stratum);
         printf("Poll: %d\n", packet.stratum);
         printf("Precision: %d\n", packet.precision);