Insert.cpp 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. /*
  2. * Copyright (c) 2021, Jan de Visser <jan@de-visser.net>
  3. * Copyright (c) 2021, Mahmoud Mandour <ma.mandourr@gmail.com>
  4. *
  5. * SPDX-License-Identifier: BSD-2-Clause
  6. */
  7. #include <LibSQL/AST/AST.h>
  8. #include <LibSQL/Database.h>
  9. #include <LibSQL/Meta.h>
  10. #include <LibSQL/Row.h>
  11. namespace SQL::AST {
  12. static bool does_value_data_type_match(SQLType expected, SQLType actual)
  13. {
  14. if (actual == SQLType::Null) {
  15. return false;
  16. }
  17. if (expected == SQLType::Integer) {
  18. return actual == SQLType::Integer || actual == SQLType::Float;
  19. }
  20. return expected == actual;
  21. }
  22. RefPtr<SQLResult> Insert::execute(ExecutionContext& context) const
  23. {
  24. auto table_def = context.database->get_table(m_schema_name, m_table_name);
  25. if (!table_def) {
  26. auto schema_name = m_schema_name;
  27. if (schema_name.is_null() || schema_name.is_empty())
  28. schema_name = "default";
  29. return SQLResult::construct(SQLCommand::Insert, SQLErrorCode::TableDoesNotExist, String::formatted("{}.{}", schema_name, m_table_name));
  30. }
  31. Row row(table_def);
  32. for (auto& column : m_column_names) {
  33. if (!row.has(column)) {
  34. return SQLResult::construct(SQLCommand::Insert, SQLErrorCode::ColumnDoesNotExist, column);
  35. }
  36. }
  37. Vector<Row> inserted_rows;
  38. inserted_rows.ensure_capacity(m_chained_expressions.size());
  39. for (auto& row_expr : m_chained_expressions) {
  40. for (auto& column_def : table_def->columns()) {
  41. if (!m_column_names.contains_slow(column_def.name())) {
  42. row[column_def.name()] = column_def.default_value();
  43. }
  44. }
  45. auto row_value = row_expr.evaluate(context);
  46. VERIFY(row_value.type() == SQLType::Tuple);
  47. auto values = row_value.to_vector().value();
  48. if (m_column_names.size() == 0 && values.size() != row.size()) {
  49. return SQLResult::construct(SQLCommand::Insert, SQLErrorCode::InvalidNumberOfValues, "");
  50. }
  51. for (auto ix = 0u; ix < values.size(); ix++) {
  52. auto input_value_type = values[ix].type();
  53. auto& tuple_descriptor = *row.descriptor();
  54. // In case of having column names, this must succeed since we checked for every column name for existence in the table.
  55. auto element_index = (m_column_names.size() == 0) ? ix : tuple_descriptor.find_if([&](auto element) { return element.name == m_column_names[ix]; }).index();
  56. auto element_type = tuple_descriptor[element_index].type;
  57. if (!does_value_data_type_match(element_type, input_value_type)) {
  58. return SQLResult::construct(SQLCommand::Insert, SQLErrorCode::InvalidValueType, table_def->columns()[element_index].name());
  59. }
  60. row[element_index] = values[ix];
  61. }
  62. inserted_rows.append(row);
  63. }
  64. for (auto& inserted_row : inserted_rows) {
  65. context.database->insert(inserted_row);
  66. }
  67. return SQLResult::construct(SQLCommand::Insert, 0, m_chained_expressions.size(), 0);
  68. }
  69. }