Matrix.h 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. /*
  2. * Copyright (c) 2020, the SerenityOS developers.
  3. *
  4. * SPDX-License-Identifier: BSD-2-Clause
  5. */
  6. #pragma once
  7. #include <AK/Types.h>
  8. #include <initializer_list>
  9. namespace Gfx {
  10. template<size_t N, typename T>
  11. class Matrix {
  12. public:
  13. static constexpr size_t Size = N;
  14. constexpr Matrix() = default;
  15. constexpr Matrix(std::initializer_list<T> elements)
  16. {
  17. VERIFY(elements.size() == N * N);
  18. size_t i = 0;
  19. for (auto& element : elements) {
  20. m_elements[i / N][i % N] = element;
  21. ++i;
  22. }
  23. }
  24. template<typename... Args>
  25. constexpr Matrix(Args... args)
  26. : Matrix({ (T)args... })
  27. {
  28. }
  29. Matrix(const Matrix& other)
  30. {
  31. __builtin_memcpy(m_elements, other.elements(), sizeof(T) * N * N);
  32. }
  33. constexpr auto elements() const { return m_elements; }
  34. constexpr auto elements() { return m_elements; }
  35. constexpr Matrix operator*(const Matrix& other) const
  36. {
  37. Matrix product;
  38. for (size_t i = 0; i < N; ++i) {
  39. for (size_t j = 0; j < N; ++j) {
  40. auto& element = product.m_elements[i][j];
  41. if constexpr (N == 4) {
  42. element = m_elements[i][0] * other.m_elements[0][j]
  43. + m_elements[i][1] * other.m_elements[1][j]
  44. + m_elements[i][2] * other.m_elements[2][j]
  45. + m_elements[i][3] * other.m_elements[3][j];
  46. } else if constexpr (N == 3) {
  47. element = m_elements[i][0] * other.m_elements[0][j]
  48. + m_elements[i][1] * other.m_elements[1][j]
  49. + m_elements[i][2] * other.m_elements[2][j];
  50. } else if constexpr (N == 2) {
  51. element = m_elements[i][0] * other.m_elements[0][j]
  52. + m_elements[i][1] * other.m_elements[1][j];
  53. } else if constexpr (N == 1) {
  54. element = m_elements[i][0] * other.m_elements[0][j];
  55. } else {
  56. T value {};
  57. for (size_t k = 0; k < N; ++k)
  58. value += m_elements[i][k] * other.m_elements[k][j];
  59. element = value;
  60. }
  61. }
  62. }
  63. return product;
  64. }
  65. constexpr Matrix operator/(T divisor) const
  66. {
  67. Matrix division;
  68. for (size_t i = 0; i < N; ++i) {
  69. for (size_t j = 0; j < N; ++j) {
  70. division.m_elements[i][j] = m_elements[i][j] / divisor;
  71. }
  72. }
  73. return division;
  74. }
  75. constexpr Matrix adjugate() const
  76. {
  77. if constexpr (N == 1)
  78. return Matrix(1);
  79. Matrix adjugate;
  80. for (size_t i = 0; i < N; ++i) {
  81. for (size_t j = 0; j < N; ++j) {
  82. int sign = (i + j) % 2 == 0 ? 1 : -1;
  83. adjugate.m_elements[j][i] = sign * first_minor(i, j);
  84. }
  85. }
  86. return adjugate;
  87. }
  88. constexpr T determinant() const
  89. {
  90. if constexpr (N == 1) {
  91. return m_elements[0][0];
  92. } else {
  93. T result = {};
  94. int sign = 1;
  95. for (size_t j = 0; j < N; ++j) {
  96. result += sign * m_elements[0][j] * first_minor(0, j);
  97. sign *= -1;
  98. }
  99. return result;
  100. }
  101. }
  102. constexpr T first_minor(size_t skip_row, size_t skip_column) const
  103. {
  104. static_assert(N > 1);
  105. VERIFY(skip_row < N);
  106. VERIFY(skip_column < N);
  107. Matrix<N - 1, T> first_minor;
  108. constexpr auto new_size = N - 1;
  109. size_t k = 0;
  110. for (size_t i = 0; i < N; ++i) {
  111. for (size_t j = 0; j < N; ++j) {
  112. if (i == skip_row || j == skip_column)
  113. continue;
  114. first_minor.elements()[k / new_size][k % new_size] = m_elements[i][j];
  115. ++k;
  116. }
  117. }
  118. return first_minor.determinant();
  119. }
  120. constexpr static Matrix identity()
  121. {
  122. Matrix result;
  123. for (size_t i = 0; i < N; ++i) {
  124. for (size_t j = 0; j < N; ++j) {
  125. if (i == j)
  126. result.m_elements[i][j] = 1;
  127. else
  128. result.m_elements[i][j] = 0;
  129. }
  130. }
  131. return result;
  132. }
  133. constexpr Matrix inverse() const
  134. {
  135. auto det = determinant();
  136. VERIFY(det != 0);
  137. return adjugate() / det;
  138. }
  139. constexpr Matrix transpose() const
  140. {
  141. Matrix result;
  142. for (size_t i = 0; i < N; ++i) {
  143. for (size_t j = 0; j < N; ++j) {
  144. result.m_elements[i][j] = m_elements[j][i];
  145. }
  146. }
  147. return result;
  148. }
  149. private:
  150. T m_elements[N][N];
  151. };
  152. }
  153. using Gfx::Matrix;