Update.cpp 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. /*
  2. * Copyright (c) 2022, Tim Flynn <trflynn89@serenityos.org>
  3. *
  4. * SPDX-License-Identifier: BSD-2-Clause
  5. */
  6. #include <LibSQL/AST/AST.h>
  7. #include <LibSQL/Database.h>
  8. #include <LibSQL/Meta.h>
  9. #include <LibSQL/Row.h>
  10. namespace SQL::AST {
  11. ResultOr<ResultSet> Update::execute(ExecutionContext& context) const
  12. {
  13. auto const& schema_name = m_qualified_table_name->schema_name();
  14. auto const& table_name = m_qualified_table_name->table_name();
  15. auto table_def = TRY(context.database->get_table(schema_name, table_name));
  16. Vector<Row> matched_rows;
  17. for (auto& table_row : TRY(context.database->select_all(*table_def))) {
  18. context.current_row = &table_row;
  19. if (auto const& where_clause = this->where_clause()) {
  20. auto where_result = TRY(where_clause->evaluate(context)).to_bool();
  21. if (!where_result.has_value() || !where_result.value())
  22. continue;
  23. }
  24. TRY(matched_rows.try_append(move(table_row)));
  25. }
  26. ResultSet result { SQLCommand::Update };
  27. for (auto& update_column : m_update_columns) {
  28. auto row_value = TRY(update_column.expression->evaluate(context));
  29. for (auto& table_row : matched_rows) {
  30. auto& row_descriptor = *table_row.descriptor();
  31. for (auto const& column_name : update_column.column_names) {
  32. if (!table_row.has(column_name))
  33. return Result { SQLCommand::Update, SQLErrorCode::ColumnDoesNotExist, column_name };
  34. auto column_index = row_descriptor.find_if([&](auto element) { return element.name == column_name; }).index();
  35. auto column_type = row_descriptor[column_index].type;
  36. if (!row_value.is_type_compatible_with(column_type))
  37. return Result { SQLCommand::Update, SQLErrorCode::InvalidValueType, column_name };
  38. table_row[column_index] = row_value;
  39. }
  40. TRY(context.database->update(table_row));
  41. result.insert_row(table_row, {});
  42. }
  43. }
  44. return result;
  45. }
  46. }