TLSv12.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. /*
  2. * Copyright (c) 2020, Ali Mohammad Pur <ali.mpfard@gmail.com>
  3. * All rights reserved.
  4. *
  5. * Redistribution and use in source and binary forms, with or without
  6. * modification, are permitted provided that the following conditions are met:
  7. *
  8. * 1. Redistributions of source code must retain the above copyright notice, this
  9. * list of conditions and the following disclaimer.
  10. *
  11. * 2. Redistributions in binary form must reproduce the above copyright notice,
  12. * this list of conditions and the following disclaimer in the documentation
  13. * and/or other materials provided with the distribution.
  14. *
  15. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  16. * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  17. * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  19. * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  20. * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  21. * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  22. * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  23. * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  24. * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. */
  26. #include <LibCore/Timer.h>
  27. #include <LibCrypto/ASN1/DER.h>
  28. #include <LibCrypto/PK/Code/EMSA_PSS.h>
  29. #include <LibTLS/TLSv12.h>
  30. //#define TLS_DEBUG
  31. namespace {
  32. struct OIDChain {
  33. void* root { nullptr };
  34. u8* oid { nullptr };
  35. };
  36. }
  37. namespace TLS {
  38. // "for now" q&d implementation of ASN1
  39. namespace {
  40. static bool _asn1_is_field_present(const u32* fields, const u32* prefix)
  41. {
  42. size_t i = 0;
  43. while (prefix[i]) {
  44. if (fields[i] != prefix[i])
  45. return false;
  46. ++i;
  47. }
  48. return true;
  49. }
  50. static bool _asn1_is_oid(const u8* oid, const u8* compare, size_t length = 3)
  51. {
  52. size_t i = 0;
  53. while (oid[i] && i < length) {
  54. if (oid[i] != compare[i])
  55. return false;
  56. ++i;
  57. }
  58. return true;
  59. }
  60. static void _set_algorithm(u32&, const u8* value, size_t length)
  61. {
  62. if (length != 9) {
  63. dbg() << "unsupported algorithm " << value;
  64. }
  65. dbg() << "FIXME: Set algorithm";
  66. }
  67. static size_t _get_asn1_length(const u8* buffer, size_t length, size_t& octets)
  68. {
  69. octets = 0;
  70. if (length < 1)
  71. return 0;
  72. u8 size = buffer[0];
  73. if (size & 0x80) {
  74. octets = size & 0x7f;
  75. if (octets > length - 1) {
  76. return 0;
  77. }
  78. auto reference_octets = octets;
  79. if (octets > 4)
  80. reference_octets = 4;
  81. size_t long_size = 0, coeff = 1;
  82. for (auto i = reference_octets; i > 0; --i) {
  83. long_size += buffer[i] * coeff;
  84. coeff *= 0x100;
  85. }
  86. ++octets;
  87. return long_size;
  88. }
  89. ++octets;
  90. return size;
  91. }
  92. static ssize_t _parse_asn1(const Context& context, Certificate& cert, const u8* buffer, size_t size, int level, u32* fields, u8* has_key, int client_cert, u8* root_oid, OIDChain* chain)
  93. {
  94. OIDChain local_chain;
  95. local_chain.root = chain;
  96. size_t position = 0;
  97. // parse DER...again
  98. size_t index = 0;
  99. u8 oid[16] { 0 };
  100. local_chain.oid = oid;
  101. if (has_key)
  102. *has_key = 0;
  103. u8 local_has_key = 0;
  104. const u8* cert_data = nullptr;
  105. size_t cert_length = 0;
  106. while (position < size) {
  107. size_t start_position = position;
  108. if (size - position < 2) {
  109. dbg() << "not enough data for certificate size";
  110. return (i8)Error::NeedMoreData;
  111. }
  112. u8 first = buffer[position++];
  113. u8 type = first & 0x1f;
  114. u8 constructed = first & 0x20;
  115. size_t octets = 0;
  116. u32 temp;
  117. index++;
  118. if (level <= 0xff)
  119. fields[level - 1] = index;
  120. size_t length = _get_asn1_length((const u8*)&buffer[position], size - position, octets);
  121. if (octets > 4 || octets > size - position) {
  122. dbg() << "could not read the certificate";
  123. return position;
  124. }
  125. position += octets;
  126. if (size - position < length) {
  127. dbg() << "not enough data for sequence";
  128. return (i8)Error::NeedMoreData;
  129. }
  130. if (length && constructed) {
  131. switch (type) {
  132. case 0x03:
  133. break;
  134. case 0x10:
  135. if (level == 2 && index == 1) {
  136. cert_length = length + position - start_position;
  137. cert_data = buffer + start_position;
  138. }
  139. // public key data
  140. if (!cert.version && _asn1_is_field_present(fields, Constants::priv_der_id)) {
  141. temp = length + position - start_position;
  142. if (cert.der.size() < temp) {
  143. cert.der.grow(temp);
  144. } else {
  145. cert.der.trim(temp);
  146. }
  147. cert.der.overwrite(0, buffer + start_position, temp);
  148. }
  149. break;
  150. default:
  151. break;
  152. }
  153. local_has_key = false;
  154. _parse_asn1(context, cert, buffer + position, length, level + 1, fields, &local_has_key, client_cert, root_oid, &local_chain);
  155. if ((local_has_key && (!context.is_server || client_cert)) || (client_cert || _asn1_is_field_present(fields, Constants::pk_id))) {
  156. temp = length + position - start_position;
  157. if (cert.der.size() < temp) {
  158. cert.der.grow(temp);
  159. } else {
  160. cert.der.trim(temp);
  161. }
  162. cert.der.overwrite(0, buffer + start_position, temp);
  163. }
  164. } else {
  165. switch (type) {
  166. case 0x00:
  167. return position;
  168. break;
  169. case 0x01:
  170. temp = buffer[position];
  171. break;
  172. case 0x02:
  173. if (_asn1_is_field_present(fields, Constants::pk_id)) {
  174. if (has_key)
  175. *has_key = true;
  176. if (index == 1)
  177. cert.public_key.set(
  178. Crypto::UnsignedBigInteger::import_data(buffer + position, length),
  179. cert.public_key.public_exponent());
  180. else if (index == 2)
  181. cert.public_key.set(
  182. cert.public_key.modulus(),
  183. Crypto::UnsignedBigInteger::import_data(buffer + position, length));
  184. } else if (_asn1_is_field_present(fields, Constants::serial_id)) {
  185. cert.serial_number = Crypto::UnsignedBigInteger::import_data(buffer + position, length);
  186. }
  187. if (_asn1_is_field_present(fields, Constants::version_id)) {
  188. if (length == 1)
  189. cert.version = buffer[position];
  190. }
  191. // print_buffer(ByteBuffer::wrap(buffer + position, length));
  192. break;
  193. case 0x03:
  194. if (_asn1_is_field_present(fields, Constants::pk_id)) {
  195. if (has_key)
  196. *has_key = true;
  197. }
  198. if (_asn1_is_field_present(fields, Constants::sign_id)) {
  199. auto* value = buffer + position;
  200. auto len = length;
  201. if (!value[0] && len % 2) {
  202. ++value;
  203. --len;
  204. }
  205. cert.sign_key = ByteBuffer::copy(value, len);
  206. } else {
  207. if (buffer[position] == 0 && length > 256) {
  208. _parse_asn1(context, cert, buffer + position + 1, length - 1, level + 1, fields, &local_has_key, client_cert, root_oid, &local_chain);
  209. } else {
  210. _parse_asn1(context, cert, buffer + position, length, level + 1, fields, &local_has_key, client_cert, root_oid, &local_chain);
  211. }
  212. }
  213. break;
  214. case 0x04:
  215. _parse_asn1(context, cert, buffer + position, length, level + 1, fields, &local_has_key, client_cert, root_oid, &local_chain);
  216. break;
  217. case 0x05:
  218. break;
  219. case 0x06:
  220. if (_asn1_is_field_present(fields, Constants::pk_id)) {
  221. _set_algorithm(cert.key_algorithm, buffer + position, length);
  222. }
  223. if (_asn1_is_field_present(fields, Constants::algorithm_id)) {
  224. _set_algorithm(cert.algorithm, buffer + position, length);
  225. }
  226. if (length < 16)
  227. memcpy(oid, buffer + position, length);
  228. else
  229. memcpy(oid, buffer + position, 16);
  230. if (root_oid)
  231. memcpy(root_oid, oid, 16);
  232. break;
  233. case 0x09:
  234. break;
  235. case 0x17:
  236. case 0x018:
  237. // time
  238. // ignore
  239. break;
  240. case 0x013:
  241. case 0x0c:
  242. case 0x14:
  243. case 0x15:
  244. case 0x16:
  245. case 0x19:
  246. case 0x1a:
  247. case 0x1b:
  248. case 0x1c:
  249. case 0x1d:
  250. case 0x1e:
  251. // printable string and such
  252. if (_asn1_is_field_present(fields, Constants::issurer_id)) {
  253. if (_asn1_is_oid(oid, Constants::country_oid)) {
  254. cert.issuer_country = String { (const char*)buffer + position, length };
  255. } else if (_asn1_is_oid(oid, Constants::state_oid)) {
  256. cert.issuer_state = String { (const char*)buffer + position, length };
  257. } else if (_asn1_is_oid(oid, Constants::location_oid)) {
  258. cert.issuer_location = String { (const char*)buffer + position, length };
  259. } else if (_asn1_is_oid(oid, Constants::entity_oid)) {
  260. cert.issuer_entity = String { (const char*)buffer + position, length };
  261. } else if (_asn1_is_oid(oid, Constants::subject_oid)) {
  262. cert.issuer_subject = String { (const char*)buffer + position, length };
  263. }
  264. } else if (_asn1_is_field_present(fields, Constants::owner_id)) {
  265. if (_asn1_is_oid(oid, Constants::country_oid)) {
  266. cert.country = String { (const char*)buffer + position, length };
  267. } else if (_asn1_is_oid(oid, Constants::state_oid)) {
  268. cert.state = String { (const char*)buffer + position, length };
  269. } else if (_asn1_is_oid(oid, Constants::location_oid)) {
  270. cert.location = String { (const char*)buffer + position, length };
  271. } else if (_asn1_is_oid(oid, Constants::entity_oid)) {
  272. cert.entity = String { (const char*)buffer + position, length };
  273. } else if (_asn1_is_oid(oid, Constants::subject_oid)) {
  274. cert.subject = String { (const char*)buffer + position, length };
  275. }
  276. }
  277. break;
  278. default:
  279. // dbg() << "unused field " << type;
  280. break;
  281. }
  282. }
  283. position += length;
  284. }
  285. if (level == 2 && cert.sign_key.size() && cert_length && cert_data) {
  286. dbg() << "FIXME: Cert.fingerprint";
  287. }
  288. return position;
  289. }
  290. }
  291. Optional<Certificate> TLSv12::parse_asn1(const ByteBuffer& buffer, bool) const
  292. {
  293. // FIXME: Our ASN.1 parser is not quite up to the task of
  294. // parsing this X.509 certificate, so for the
  295. // time being, we will "parse" the certificate
  296. // manually right here.
  297. Certificate cert;
  298. u32 fields[0xff];
  299. _parse_asn1(m_context, cert, buffer.data(), buffer.size(), 1, fields, nullptr, 0, nullptr, nullptr);
  300. #ifdef TLS_DEBUG
  301. dbg() << "Certificate issued for " << cert.subject << " by " << cert.issuer_subject;
  302. #endif
  303. return cert;
  304. }
  305. ssize_t TLSv12::handle_certificate(const ByteBuffer& buffer)
  306. {
  307. ssize_t res = 0;
  308. if (buffer.size() < 3) {
  309. dbg() << "not enough certificate header data";
  310. return (i8)Error::NeedMoreData;
  311. }
  312. u32 certificate_total_length = buffer[0] * 0x10000 + buffer[1] * 0x100 + buffer[2];
  313. dbg() << "total length: " << certificate_total_length;
  314. if (certificate_total_length <= 4)
  315. return 3 * certificate_total_length;
  316. res += 3;
  317. if (certificate_total_length > buffer.size() - res) {
  318. dbg() << "not enough data for claimed total cert length";
  319. return (i8)Error::NeedMoreData;
  320. }
  321. size_t size = certificate_total_length;
  322. size_t index = 0;
  323. bool valid_certificate = false;
  324. while (size > 0) {
  325. ++index;
  326. if (buffer.size() - res < 3) {
  327. dbg() << "not enough data for certificate length";
  328. return (i8)Error::NeedMoreData;
  329. }
  330. size_t certificate_size = buffer[res] * 0x10000 + buffer[res + 1] * 0x100 + buffer[res + 2];
  331. res += 3;
  332. if (buffer.size() - res < certificate_size) {
  333. dbg() << "not enough data for certificate body";
  334. return (i8)Error::NeedMoreData;
  335. }
  336. auto res_cert = res;
  337. auto remaining = certificate_size;
  338. size_t certificates_in_chain = 0;
  339. do {
  340. if (remaining <= 3)
  341. break;
  342. ++certificates_in_chain;
  343. if (buffer.size() < (size_t)res_cert + 3)
  344. break;
  345. size_t certificate_size_specific = buffer[res_cert] * 0x10000 + buffer[res_cert + 1] * 0x100 + buffer[res_cert + 2];
  346. res_cert += 3;
  347. remaining -= 3;
  348. if (certificate_size_specific > remaining) {
  349. dbg() << "invalid certificate size (expected " << remaining << " but got " << certificate_size_specific << ")";
  350. break;
  351. }
  352. remaining -= certificate_size_specific;
  353. auto certificate = parse_asn1(buffer.slice_view(res_cert, certificate_size_specific), false);
  354. if (certificate.has_value()) {
  355. m_context.certificates.append(certificate.value());
  356. valid_certificate = true;
  357. }
  358. res_cert += certificate_size;
  359. } while (remaining > 0);
  360. if (remaining) {
  361. dbg() << "extraneous " << remaining << " bytes left over after parsing certificates";
  362. }
  363. size -= certificate_size + 3;
  364. res += certificate_size;
  365. }
  366. if (!valid_certificate)
  367. return (i8)Error::UnsupportedCertificate;
  368. if ((size_t)res != buffer.size())
  369. dbg() << "some data left unread: " << (size_t)res << " bytes out of " << buffer.size();
  370. return res;
  371. }
  372. void TLSv12::consume(const ByteBuffer& record)
  373. {
  374. if (m_context.critical_error) {
  375. dbg() << "There has been a critical error (" << (i8)m_context.critical_error << "), refusing to continue";
  376. return;
  377. }
  378. if (record.size() == 0) {
  379. return;
  380. }
  381. #ifdef TLS_DEBUG
  382. dbg() << "Consuming " << record.size() << " bytes";
  383. #endif
  384. m_context.message_buffer.append(record.data(), record.size());
  385. size_t index { 0 };
  386. size_t buffer_length = m_context.message_buffer.size();
  387. size_t size_offset { 3 }; // read the common record header
  388. size_t header_size { 5 };
  389. #ifdef TLS_DEBUG
  390. dbg() << "message buffer length " << buffer_length;
  391. #endif
  392. while (buffer_length >= 5) {
  393. auto length = convert_between_host_and_network(*(u16*)m_context.message_buffer.offset_pointer(index + size_offset)) + header_size;
  394. if (length > buffer_length) {
  395. #ifdef TLS_DEBUG
  396. dbg() << "Need more data: " << length << " | " << buffer_length;
  397. #endif
  398. break;
  399. }
  400. auto consumed = handle_message(m_context.message_buffer.slice_view(index, length));
  401. #ifdef TLS_DEBUG
  402. if (consumed > 0)
  403. dbg() << "consumed " << (size_t)consumed << " bytes";
  404. else
  405. dbg() << "error: " << (int)consumed;
  406. #endif
  407. if (consumed != (i8)Error::NeedMoreData) {
  408. if (consumed < 0) {
  409. dbg() << "Consumed an error: " << (int)consumed;
  410. if (!m_context.critical_error)
  411. m_context.critical_error = (i8)consumed;
  412. m_context.error_code = (Error)consumed;
  413. break;
  414. }
  415. } else {
  416. continue;
  417. }
  418. index += length;
  419. buffer_length -= length;
  420. if (m_context.critical_error) {
  421. dbg() << "Broken connection";
  422. m_context.error_code = Error::BrokenConnection;
  423. break;
  424. }
  425. }
  426. if (m_context.error_code != Error::NoError && m_context.error_code != Error::NeedMoreData) {
  427. dbg() << "consume error: " << (i8)m_context.error_code;
  428. m_context.message_buffer.clear();
  429. return;
  430. }
  431. if (index) {
  432. m_context.message_buffer = m_context.message_buffer.slice(index, m_context.message_buffer.size() - index);
  433. }
  434. }
  435. void TLSv12::ensure_hmac(size_t digest_size, bool local)
  436. {
  437. if (local && m_hmac_local)
  438. return;
  439. if (!local && m_hmac_remote)
  440. return;
  441. auto hash_kind = Crypto::Hash::HashKind::None;
  442. switch (digest_size) {
  443. case Crypto::Hash::SHA1::DigestSize:
  444. hash_kind = Crypto::Hash::HashKind::SHA1;
  445. break;
  446. case Crypto::Hash::SHA256::DigestSize:
  447. hash_kind = Crypto::Hash::HashKind::SHA256;
  448. break;
  449. case Crypto::Hash::SHA512::DigestSize:
  450. hash_kind = Crypto::Hash::HashKind::SHA512;
  451. break;
  452. default:
  453. dbg() << "Failed to find a suitable hash for size " << digest_size;
  454. break;
  455. }
  456. auto hmac = make<Crypto::Authentication::HMAC<Crypto::Hash::Manager>>(ByteBuffer::wrap(local ? m_context.crypto.local_mac : m_context.crypto.remote_mac, digest_size), hash_kind);
  457. if (local)
  458. m_hmac_local = move(hmac);
  459. else
  460. m_hmac_remote = move(hmac);
  461. }
  462. }