From d6dee8c0e838da8cf43bdac5223665668a2d9980 Mon Sep 17 00:00:00 2001 From: Timothy Flynn Date: Fri, 3 Feb 2023 10:33:10 -0500 Subject: [PATCH] LibSQL+Userland: Pass SQL IPC results to clients in a structure SQLClient exists as a wrapper around SQL IPC to provide a bit friendlier interface for clients to deal with. Though right now, it mostly forwards values as-is from IPC to the clients. This makes it a bit verbose to add values to IPC responses, as we then have to add it to the callbacks used by all clients. It's also a bit confusing seeing a sea of "auto" as the parameter types for these callbacks. This patch moves these response values to named structures instead. This will allow adding values without needing to simultaneously update all clients. We can then separately handle the new values in interested clients only. --- Userland/Applications/Browser/Database.cpp | 22 +++--- Userland/Applications/Browser/Database.h | 6 ++ Userland/DevTools/SQLStudio/MainWidget.cpp | 14 ++-- Userland/Libraries/LibSQL/SQLClient.cpp | 83 +++++++++++++++------- Userland/Libraries/LibSQL/SQLClient.h | 42 +++++++++-- Userland/Utilities/sql.cpp | 22 +++--- 6 files changed, 127 insertions(+), 62 deletions(-) diff --git a/Userland/Applications/Browser/Database.cpp b/Userland/Applications/Browser/Database.cpp index e74ab95da39..6931be0b281 100644 --- a/Userland/Applications/Browser/Database.cpp +++ b/Userland/Applications/Browser/Database.cpp @@ -30,11 +30,11 @@ Database::Database(NonnullRefPtr sql_client, SQL::ConnectionID c : m_sql_client(move(sql_client)) , m_connection_id(connection_id) { - m_sql_client->on_execution_success = [this](auto statement_id, auto execution_id, auto has_results, auto, auto, auto) { - if (has_results) + m_sql_client->on_execution_success = [this](auto result) { + if (result.has_results) return; - if (auto it = m_pending_executions.find({ statement_id, execution_id }); it != m_pending_executions.end()) { + if (auto it = find_pending_execution(result); it != m_pending_executions.end()) { auto in_progress_statement = move(it->value); m_pending_executions.remove(it); @@ -43,15 +43,15 @@ Database::Database(NonnullRefPtr sql_client, SQL::ConnectionID c } }; - m_sql_client->on_next_result = [this](auto statement_id, auto execution_id, auto row) { - if (auto it = m_pending_executions.find({ statement_id, execution_id }); it != m_pending_executions.end()) { + m_sql_client->on_next_result = [this](auto result) { + if (auto it = find_pending_execution(result); it != m_pending_executions.end()) { if (it->value.on_result) - it->value.on_result(row); + it->value.on_result(result.values); } }; - m_sql_client->on_results_exhausted = [this](auto statement_id, auto execution_id, auto) { - if (auto it = m_pending_executions.find({ statement_id, execution_id }); it != m_pending_executions.end()) { + m_sql_client->on_results_exhausted = [this](auto result) { + if (auto it = find_pending_execution(result); it != m_pending_executions.end()) { auto in_progress_statement = move(it->value); m_pending_executions.remove(it); @@ -60,13 +60,13 @@ Database::Database(NonnullRefPtr sql_client, SQL::ConnectionID c } }; - m_sql_client->on_execution_error = [this](auto statement_id, auto execution_id, auto, auto const& message) { - if (auto it = m_pending_executions.find({ statement_id, execution_id }); it != m_pending_executions.end()) { + m_sql_client->on_execution_error = [this](auto result) { + if (auto it = find_pending_execution(result); it != m_pending_executions.end()) { auto in_progress_statement = move(it->value); m_pending_executions.remove(it); if (in_progress_statement.on_error) - in_progress_statement.on_error(message); + in_progress_statement.on_error(result.error_message); } }; } diff --git a/Userland/Applications/Browser/Database.h b/Userland/Applications/Browser/Database.h index 74eb58bac37..bbe05e198ca 100644 --- a/Userland/Applications/Browser/Database.h +++ b/Userland/Applications/Browser/Database.h @@ -68,6 +68,12 @@ private: Database(NonnullRefPtr sql_client, SQL::ConnectionID connection_id); void execute_statement(SQL::StatementID statement_id, Vector placeholder_values, PendingExecution pending_execution); + template + auto find_pending_execution(ResultData const& result_data) + { + return m_pending_executions.find({ result_data.statement_id, result_data.execution_id }); + } + NonnullRefPtr m_sql_client; SQL::ConnectionID m_connection_id { 0 }; diff --git a/Userland/DevTools/SQLStudio/MainWidget.cpp b/Userland/DevTools/SQLStudio/MainWidget.cpp index 4de878b9cfd..08cfa90d4c3 100644 --- a/Userland/DevTools/SQLStudio/MainWidget.cpp +++ b/Userland/DevTools/SQLStudio/MainWidget.cpp @@ -253,23 +253,23 @@ MainWidget::MainWidget() }; m_sql_client = SQL::SQLClient::try_create().release_value_but_fixme_should_propagate_errors(); - m_sql_client->on_execution_success = [this](auto, auto, auto, auto, auto, auto) { + m_sql_client->on_execution_success = [this](auto) { read_next_sql_statement_of_editor(); }; - m_sql_client->on_execution_error = [this](auto, auto, auto, auto message) { + m_sql_client->on_execution_error = [this](auto result) { auto* editor = active_editor(); VERIFY(editor); - GUI::MessageBox::show_error(window(), DeprecatedString::formatted("Error executing {}\n{}", editor->path(), message)); + GUI::MessageBox::show_error(window(), DeprecatedString::formatted("Error executing {}\n{}", editor->path(), result.error_message)); }; - m_sql_client->on_next_result = [this](auto, auto, auto row) { + m_sql_client->on_next_result = [this](auto result) { m_results.append({}); - m_results.last().ensure_capacity(row.size()); + m_results.last().ensure_capacity(result.values.size()); - for (auto const& value : row) + for (auto const& value : result.values) m_results.last().unchecked_append(value.to_deprecated_string()); }; - m_sql_client->on_results_exhausted = [this](auto, auto, auto) { + m_sql_client->on_results_exhausted = [this](auto) { if (m_results.size() == 0) return; if (m_results[0].size() == 0) diff --git a/Userland/Libraries/LibSQL/SQLClient.cpp b/Userland/Libraries/LibSQL/SQLClient.cpp index bab488ef0fe..1b0705ed165 100644 --- a/Userland/Libraries/LibSQL/SQLClient.cpp +++ b/Userland/Libraries/LibSQL/SQLClient.cpp @@ -154,45 +154,74 @@ ErrorOr> SQLClient::launch_server_and_create_client(Vec #endif -void SQLClient::execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message) -{ - if (on_execution_error) - on_execution_error(statement_id, execution_id, code, message); - else - warnln("Execution error for statement_id {}: {} ({})", statement_id, message, to_underlying(code)); -} - void SQLClient::execution_success(u64 statement_id, u64 execution_id, bool has_results, size_t created, size_t updated, size_t deleted) { - if (on_execution_success) - on_execution_success(statement_id, execution_id, has_results, created, updated, deleted); - else + if (!on_execution_success) { outln("{} row(s) created, {} updated, {} deleted", created, updated, deleted); -} - -void SQLClient::next_result(u64 statement_id, u64 execution_id, Vector const& row) -{ - if (on_next_result) { - on_next_result(statement_id, execution_id, row); return; } - bool first = true; - for (auto& column : row) { - if (!first) - out(", "); - out("\"{}\"", column); - first = false; + ExecutionSuccess success { + .statement_id = statement_id, + .execution_id = execution_id, + .has_results = has_results, + .rows_created = created, + .rows_updated = updated, + .rows_deleted = deleted, + }; + + on_execution_success(move(success)); +} + +void SQLClient::execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message) +{ + if (!on_execution_error) { + warnln("Execution error for statement_id {}: {} ({})", statement_id, message, to_underlying(code)); + return; } - outln(); + + ExecutionError error { + .statement_id = statement_id, + .execution_id = execution_id, + .error_code = code, + .error_message = move(const_cast(message)), + }; + + on_execution_error(move(error)); +} + +void SQLClient::next_result(u64 statement_id, u64 execution_id, Vector const& row) +{ + if (!on_next_result) { + StringBuilder builder; + builder.join(", "sv, row, "\"{}\""sv); + outln("{}", builder.string_view()); + return; + } + + ExecutionResult result { + .statement_id = statement_id, + .execution_id = execution_id, + .values = move(const_cast&>(row)), + }; + + on_next_result(move(result)); } void SQLClient::results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows) { - if (on_results_exhausted) - on_results_exhausted(statement_id, execution_id, total_rows); - else + if (!on_results_exhausted) { outln("{} total row(s)", total_rows); + return; + } + + ExecutionComplete success { + .statement_id = statement_id, + .execution_id = execution_id, + .total_rows = total_rows, + }; + + on_results_exhausted(move(success)); } } diff --git a/Userland/Libraries/LibSQL/SQLClient.h b/Userland/Libraries/LibSQL/SQLClient.h index 76c4a94f7eb..97c37a7ea87 100644 --- a/Userland/Libraries/LibSQL/SQLClient.h +++ b/Userland/Libraries/LibSQL/SQLClient.h @@ -15,6 +15,38 @@ namespace SQL { +struct ExecutionSuccess { + u64 statement_id { 0 }; + u64 execution_id { 0 }; + + bool has_results { false }; + size_t rows_created { 0 }; + size_t rows_updated { 0 }; + size_t rows_deleted { 0 }; +}; + +struct ExecutionError { + u64 statement_id { 0 }; + u64 execution_id { 0 }; + + SQLErrorCode error_code; + DeprecatedString error_message; +}; + +struct ExecutionResult { + u64 statement_id { 0 }; + u64 execution_id { 0 }; + + Vector values; +}; + +struct ExecutionComplete { + u64 statement_id { 0 }; + u64 execution_id { 0 }; + + size_t total_rows { 0 }; +}; + class SQLClient : public IPC::ConnectionToServer , public SQLClientEndpoint { @@ -27,10 +59,10 @@ public: virtual ~SQLClient() = default; - Function on_execution_error; - Function on_execution_success; - Function)> on_next_result; - Function on_results_exhausted; + Function on_execution_success; + Function on_execution_error; + Function on_next_result; + Function on_results_exhausted; private: explicit SQLClient(NonnullOwnPtr socket) @@ -39,9 +71,9 @@ private: } virtual void execution_success(u64 statement_id, u64 execution_id, bool has_results, size_t created, size_t updated, size_t deleted) override; + virtual void execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message) override; virtual void next_result(u64 statement_id, u64 execution_id, Vector const&) override; virtual void results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows) override; - virtual void execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message) override; }; } diff --git a/Userland/Utilities/sql.cpp b/Userland/Utilities/sql.cpp index b263973b01e..d4d630cbf2e 100644 --- a/Userland/Utilities/sql.cpp +++ b/Userland/Utilities/sql.cpp @@ -76,28 +76,26 @@ public: m_editor->set_prompt(prompt_for_level(open_indents)); }; - m_sql_client->on_execution_success = [this](auto, auto, auto has_results, auto created, auto updated, auto deleted) { - if (updated != 0 || created != 0 || deleted != 0) { - outln("{} row(s) created, {} updated, {} deleted", created, updated, deleted); - } - if (!has_results) { + m_sql_client->on_execution_success = [this](auto result) { + if (result.rows_updated != 0 || result.rows_created != 0 || result.rows_deleted != 0) + outln("{} row(s) created, {} updated, {} deleted", result.rows_created, result.rows_updated, result.rows_deleted); + if (!result.has_results) read_sql(); - } }; - m_sql_client->on_next_result = [](auto, auto, auto row) { + m_sql_client->on_next_result = [](auto result) { StringBuilder builder; - builder.join(", "sv, row); + builder.join(", "sv, result.values); outln("{}", builder.to_deprecated_string()); }; - m_sql_client->on_results_exhausted = [this](auto, auto, auto total_rows) { - outln("{} row(s)", total_rows); + m_sql_client->on_results_exhausted = [this](auto result) { + outln("{} row(s)", result.total_rows); read_sql(); }; - m_sql_client->on_execution_error = [this](auto, auto, auto, auto const& message) { - outln("\033[33;1mExecution error:\033[0m {}", message); + m_sql_client->on_execution_error = [this](auto result) { + outln("\033[33;1mExecution error:\033[0m {}", result.error_message); read_sql(); };