Expression.cpp 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. /*
  2. * Copyright (c) 2021, Jan de Visser <jan@de-visser.net>
  3. *
  4. * SPDX-License-Identifier: BSD-2-Clause
  5. */
  6. #include <LibRegex/Regex.h>
  7. #include <LibSQL/AST/AST.h>
  8. #include <LibSQL/Database.h>
  9. namespace SQL::AST {
  10. static const String s_posix_basic_metacharacters = ".^$*[]+\\";
  11. Value Expression::evaluate(ExecutionContext&) const
  12. {
  13. return Value::null();
  14. }
  15. Value NumericLiteral::evaluate(ExecutionContext& context) const
  16. {
  17. if (context.result->has_error())
  18. return Value::null();
  19. Value ret(SQLType::Float);
  20. ret = value();
  21. return ret;
  22. }
  23. Value StringLiteral::evaluate(ExecutionContext& context) const
  24. {
  25. if (context.result->has_error())
  26. return Value::null();
  27. Value ret(SQLType::Text);
  28. ret = value();
  29. return ret;
  30. }
  31. Value NullLiteral::evaluate(ExecutionContext&) const
  32. {
  33. return Value::null();
  34. }
  35. Value NestedExpression::evaluate(ExecutionContext& context) const
  36. {
  37. if (context.result->has_error())
  38. return Value::null();
  39. return expression()->evaluate(context);
  40. }
  41. Value ChainedExpression::evaluate(ExecutionContext& context) const
  42. {
  43. if (context.result->has_error())
  44. return Value::null();
  45. Value ret(SQLType::Tuple);
  46. Vector<Value> values;
  47. for (auto& expression : expressions()) {
  48. values.append(expression.evaluate(context));
  49. }
  50. ret = values;
  51. return ret;
  52. }
  53. Value BinaryOperatorExpression::evaluate(ExecutionContext& context) const
  54. {
  55. if (context.result->has_error())
  56. return Value::null();
  57. Value lhs_value = lhs()->evaluate(context);
  58. Value rhs_value = rhs()->evaluate(context);
  59. switch (type()) {
  60. case BinaryOperator::Concatenate: {
  61. if (lhs_value.type() != SQLType::Text) {
  62. context.result->set_error(SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()));
  63. return Value::null();
  64. }
  65. AK::StringBuilder builder;
  66. builder.append(lhs_value.to_string());
  67. builder.append(rhs_value.to_string());
  68. return Value(builder.to_string());
  69. }
  70. case BinaryOperator::Multiplication:
  71. return lhs_value.multiply(rhs_value);
  72. case BinaryOperator::Division:
  73. return lhs_value.divide(rhs_value);
  74. case BinaryOperator::Modulo:
  75. return lhs_value.modulo(rhs_value);
  76. case BinaryOperator::Plus:
  77. return lhs_value.add(rhs_value);
  78. case BinaryOperator::Minus:
  79. return lhs_value.subtract(rhs_value);
  80. case BinaryOperator::ShiftLeft:
  81. return lhs_value.shift_left(rhs_value);
  82. case BinaryOperator::ShiftRight:
  83. return lhs_value.shift_right(rhs_value);
  84. case BinaryOperator::BitwiseAnd:
  85. return lhs_value.bitwise_and(rhs_value);
  86. case BinaryOperator::BitwiseOr:
  87. return lhs_value.bitwise_or(rhs_value);
  88. case BinaryOperator::LessThan:
  89. return Value(lhs_value.compare(rhs_value) < 0);
  90. case BinaryOperator::LessThanEquals:
  91. return Value(lhs_value.compare(rhs_value) <= 0);
  92. case BinaryOperator::GreaterThan:
  93. return Value(lhs_value.compare(rhs_value) > 0);
  94. case BinaryOperator::GreaterThanEquals:
  95. return Value(lhs_value.compare(rhs_value) >= 0);
  96. case BinaryOperator::Equals:
  97. return Value(lhs_value.compare(rhs_value) == 0);
  98. case BinaryOperator::NotEquals:
  99. return Value(lhs_value.compare(rhs_value) != 0);
  100. case BinaryOperator::And: {
  101. auto lhs_bool_maybe = lhs_value.to_bool();
  102. auto rhs_bool_maybe = rhs_value.to_bool();
  103. if (!lhs_bool_maybe.has_value() || !rhs_bool_maybe.has_value()) {
  104. context.result->set_error(SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()));
  105. return Value::null();
  106. }
  107. return Value(lhs_bool_maybe.release_value() && rhs_bool_maybe.release_value());
  108. }
  109. case BinaryOperator::Or: {
  110. auto lhs_bool_maybe = lhs_value.to_bool();
  111. auto rhs_bool_maybe = rhs_value.to_bool();
  112. if (!lhs_bool_maybe.has_value() || !rhs_bool_maybe.has_value()) {
  113. context.result->set_error(SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()));
  114. return Value::null();
  115. }
  116. return Value(lhs_bool_maybe.release_value() || rhs_bool_maybe.release_value());
  117. }
  118. default:
  119. VERIFY_NOT_REACHED();
  120. }
  121. }
  122. Value UnaryOperatorExpression::evaluate(ExecutionContext& context) const
  123. {
  124. if (context.result->has_error())
  125. return Value::null();
  126. Value expression_value = NestedExpression::evaluate(context);
  127. switch (type()) {
  128. case UnaryOperator::Plus:
  129. if (expression_value.type() == SQLType::Integer || expression_value.type() == SQLType::Float)
  130. return expression_value;
  131. context.result->set_error(SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()));
  132. return Value::null();
  133. case UnaryOperator::Minus:
  134. if (expression_value.type() == SQLType::Integer) {
  135. expression_value = -int(expression_value);
  136. return expression_value;
  137. }
  138. if (expression_value.type() == SQLType::Float) {
  139. expression_value = -double(expression_value);
  140. return expression_value;
  141. }
  142. context.result->set_error(SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()));
  143. return Value::null();
  144. case UnaryOperator::Not:
  145. if (expression_value.type() == SQLType::Boolean) {
  146. expression_value = !bool(expression_value);
  147. return expression_value;
  148. }
  149. context.result->set_error(SQLErrorCode::BooleanOperatorTypeMismatch, UnaryOperator_name(type()));
  150. return Value::null();
  151. case UnaryOperator::BitwiseNot:
  152. if (expression_value.type() == SQLType::Integer) {
  153. expression_value = ~u32(expression_value);
  154. return expression_value;
  155. }
  156. context.result->set_error(SQLErrorCode::IntegerOperatorTypeMismatch, UnaryOperator_name(type()));
  157. return Value::null();
  158. }
  159. VERIFY_NOT_REACHED();
  160. }
  161. Value ColumnNameExpression::evaluate(ExecutionContext& context) const
  162. {
  163. if (!context.current_row) {
  164. context.result->set_error(SQLErrorCode::SyntaxError, column_name());
  165. return Value::null();
  166. }
  167. auto& descriptor = *context.current_row->descriptor();
  168. VERIFY(context.current_row->size() == descriptor.size());
  169. Optional<size_t> index_in_row;
  170. for (auto ix = 0u; ix < context.current_row->size(); ix++) {
  171. auto& column_descriptor = descriptor[ix];
  172. if (!table_name().is_empty() && column_descriptor.table != table_name())
  173. continue;
  174. if (column_descriptor.name == column_name()) {
  175. if (index_in_row.has_value()) {
  176. context.result->set_error(SQLErrorCode::AmbiguousColumnName, column_name());
  177. return Value::null();
  178. }
  179. index_in_row = ix;
  180. }
  181. }
  182. if (index_in_row.has_value())
  183. return (*context.current_row)[index_in_row.value()];
  184. context.result->set_error(SQLErrorCode::ColumnDoesNotExist, column_name());
  185. return Value::null();
  186. }
  187. Value MatchExpression::evaluate(ExecutionContext& context) const
  188. {
  189. if (context.result->has_error())
  190. return Value::null();
  191. switch (type()) {
  192. case MatchOperator::Like: {
  193. Value lhs_value = lhs()->evaluate(context);
  194. Value rhs_value = rhs()->evaluate(context);
  195. char escape_char = '\0';
  196. if (escape()) {
  197. auto escape_str = escape()->evaluate(context).to_string();
  198. if (escape_str.length() != 1) {
  199. context.result->set_error(SQLErrorCode::SyntaxError, "ESCAPE should be a single character");
  200. return Value::null();
  201. }
  202. escape_char = escape_str[0];
  203. }
  204. // Compile the pattern into a simple regex.
  205. // https://sqlite.org/lang_expr.html#the_like_glob_regexp_and_match_operators
  206. bool escaped = false;
  207. AK::StringBuilder builder;
  208. builder.append('^');
  209. for (auto c : rhs_value.to_string()) {
  210. if (escape() && c == escape_char && !escaped) {
  211. escaped = true;
  212. } else if (s_posix_basic_metacharacters.contains(c)) {
  213. escaped = false;
  214. builder.append('\\');
  215. builder.append(c);
  216. } else if (c == '_' && !escaped) {
  217. builder.append('.');
  218. } else if (c == '%' && !escaped) {
  219. builder.append(".*");
  220. } else {
  221. escaped = false;
  222. builder.append(c);
  223. }
  224. }
  225. builder.append('$');
  226. // FIXME: We should probably cache this regex.
  227. auto regex = Regex<PosixBasic>(builder.build());
  228. auto result = regex.match(lhs_value.to_string(), PosixFlags::Insensitive | PosixFlags::Unicode);
  229. return Value(invert_expression() ? !result.success : result.success);
  230. }
  231. case MatchOperator::Regexp: {
  232. Value lhs_value = lhs()->evaluate(context);
  233. Value rhs_value = rhs()->evaluate(context);
  234. auto regex = Regex<PosixExtended>(rhs_value.to_string());
  235. auto err = regex.parser_result.error;
  236. if (err != regex::Error::NoError) {
  237. StringBuilder builder;
  238. builder.append("Regular expression: ");
  239. builder.append(get_error_string(err));
  240. context.result->set_error(SQLErrorCode::SyntaxError, builder.build());
  241. return Value(false);
  242. }
  243. auto result = regex.match(lhs_value.to_string(), PosixFlags::Insensitive | PosixFlags::Unicode);
  244. return Value(invert_expression() ? !result.success : result.success);
  245. }
  246. case MatchOperator::Glob:
  247. case MatchOperator::Match:
  248. default:
  249. VERIFY_NOT_REACHED();
  250. }
  251. return Value::null();
  252. }
  253. }