ladybird/Userland/Libraries/LibSQL/AST/Expression.cpp
Timothy Flynn f3c6cb40d7 LibSQL: Convert SQL expression evaluation to use ResultOr
Instead of setting an error in the execution context, we can directly
return that error or the successful value. This lets all callers, who
were already TRY-capable, simply TRY the expression evaluation.
2022-02-10 23:11:13 +01:00

249 lines
9 KiB
C++

/*
* Copyright (c) 2021, Jan de Visser <jan@de-visser.net>
*
* SPDX-License-Identifier: BSD-2-Clause
*/
#include <LibRegex/Regex.h>
#include <LibSQL/AST/AST.h>
#include <LibSQL/Database.h>
namespace SQL::AST {
// Characters escaped with a backslash when a LIKE pattern is translated into a
// LibRegex PosixBasic expression, so that they are matched literally.
static String const s_posix_basic_metacharacters = ".^$*[]+\\";
// The generic Expression evaluates to SQL NULL regardless of the execution context.
ResultOr<Value> Expression::evaluate(ExecutionContext&) const
{
    auto null_value = Value::null();
    return null_value;
}
// A numeric literal evaluates to a Float value holding the literal's number.
ResultOr<Value> NumericLiteral::evaluate(ExecutionContext&) const
{
    Value number(SQLType::Float);
    number = value();
    return number;
}
// A string literal evaluates to a Text value holding the literal's characters.
ResultOr<Value> StringLiteral::evaluate(ExecutionContext&) const
{
    Value text(SQLType::Text);
    text = value();
    return text;
}
// The NULL literal always evaluates to SQL NULL; the context is not consulted.
ResultOr<Value> NullLiteral::evaluate(ExecutionContext&) const
{
    auto null_value = Value::null();
    return null_value;
}
// A nested (parenthesized) expression evaluates to whatever its inner expression yields,
// propagating any error from it unchanged.
ResultOr<Value> NestedExpression::evaluate(ExecutionContext& context) const
{
    auto const& inner = expression();
    return inner->evaluate(context);
}
// Evaluates every sub-expression in order and packs the results into a Tuple value.
// The first failing sub-expression aborts evaluation and its error is returned.
ResultOr<Value> ChainedExpression::evaluate(ExecutionContext& context) const
{
    Vector<Value> elements;
    for (auto& element_expression : expressions()) {
        auto element = TRY(element_expression.evaluate(context));
        elements.append(move(element));
    }

    Value tuple(SQLType::Tuple);
    tuple = elements;
    return tuple;
}
// Evaluates the left operand, then the right operand, and combines them with the operator.
// Both operands are always evaluated, so AND/OR do not short-circuit.
// Errors from either operand are propagated; type errors are reported with the operator name.
ResultOr<Value> BinaryOperatorExpression::evaluate(ExecutionContext& context) const
{
    auto left = TRY(lhs()->evaluate(context));
    auto right = TRY(rhs()->evaluate(context));

    switch (type()) {
    case BinaryOperator::Concatenate: {
        // NOTE(review): a non-Text left operand is reported as BooleanOperatorTypeMismatch;
        // a text-specific error code may describe this failure more accurately.
        if (left.type() != SQLType::Text)
            return Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) };
        AK::StringBuilder concatenated;
        concatenated.append(left.to_string());
        concatenated.append(right.to_string());
        return Value(concatenated.to_string());
    }

    // Arithmetic and bitwise operators are delegated to Value.
    case BinaryOperator::Plus:
        return left.add(right);
    case BinaryOperator::Minus:
        return left.subtract(right);
    case BinaryOperator::Multiplication:
        return left.multiply(right);
    case BinaryOperator::Division:
        return left.divide(right);
    case BinaryOperator::Modulo:
        return left.modulo(right);
    case BinaryOperator::ShiftLeft:
        return left.shift_left(right);
    case BinaryOperator::ShiftRight:
        return left.shift_right(right);
    case BinaryOperator::BitwiseAnd:
        return left.bitwise_and(right);
    case BinaryOperator::BitwiseOr:
        return left.bitwise_or(right);

    // Comparisons are derived from the sign of Value::compare().
    case BinaryOperator::Equals:
        return Value(left.compare(right) == 0);
    case BinaryOperator::NotEquals:
        return Value(left.compare(right) != 0);
    case BinaryOperator::LessThan:
        return Value(left.compare(right) < 0);
    case BinaryOperator::LessThanEquals:
        return Value(left.compare(right) <= 0);
    case BinaryOperator::GreaterThan:
        return Value(left.compare(right) > 0);
    case BinaryOperator::GreaterThanEquals:
        return Value(left.compare(right) >= 0);

    // Logical operators need both operands to be convertible to bool.
    case BinaryOperator::And:
    case BinaryOperator::Or: {
        auto left_bool = left.to_bool();
        auto right_bool = right.to_bool();
        if (!left_bool.has_value() || !right_bool.has_value())
            return Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) };
        bool const a = left_bool.release_value();
        bool const b = right_bool.release_value();
        return Value(type() == BinaryOperator::And ? (a && b) : (a || b));
    }

    default:
        VERIFY_NOT_REACHED();
    }
}
// Evaluates the operand (through NestedExpression::evaluate) and applies the unary operator:
//  - '+'  returns Integer and Float operands unchanged;
//  - '-'  negates Integer and Float operands;
//  - NOT  inverts Boolean operands;
//  - '~'  bitwise-inverts Integer operands (via a u32 conversion).
// Any other operand type yields the matching *OperatorTypeMismatch error, naming the operator.
ResultOr<Value> UnaryOperatorExpression::evaluate(ExecutionContext& context) const
{
    Value expression_value = TRY(NestedExpression::evaluate(context));
    switch (type()) {
    case UnaryOperator::Plus:
        if (expression_value.type() == SQLType::Integer || expression_value.type() == SQLType::Float)
            return expression_value;
        return Result { SQLCommand::Unknown, SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()) };
    case UnaryOperator::Minus:
        if (expression_value.type() == SQLType::Integer) {
            auto integer = int(expression_value);
            int negated = 0;
            // Negating the most negative int cannot be represented in an int. Doing it with
            // plain unary minus is undefined behavior. In that single case, promote the result
            // to Float instead. SQLite likewise falls back to floating point when integer
            // arithmetic overflows.
            if (__builtin_sub_overflow(0, integer, &negated)) {
                Value promoted(SQLType::Float);
                promoted = -static_cast<double>(integer);
                return promoted;
            }
            expression_value = negated;
            return expression_value;
        }
        if (expression_value.type() == SQLType::Float) {
            expression_value = -double(expression_value);
            return expression_value;
        }
        return Result { SQLCommand::Unknown, SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()) };
    case UnaryOperator::Not:
        if (expression_value.type() == SQLType::Boolean) {
            expression_value = !bool(expression_value);
            return expression_value;
        }
        return Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, UnaryOperator_name(type()) };
    case UnaryOperator::BitwiseNot:
        if (expression_value.type() == SQLType::Integer) {
            expression_value = ~u32(expression_value);
            return expression_value;
        }
        return Result { SQLCommand::Unknown, SQLErrorCode::IntegerOperatorTypeMismatch, UnaryOperator_name(type()) };
    default:
        VERIFY_NOT_REACHED();
    }
}
// Resolves a column reference, optionally qualified with a table name, against
// context.current_row, and returns a copy of that column's value.
// Errors:
//  - SyntaxError when there is no current row;
//  - AmbiguousColumnName when a second matching column is found;
//  - ColumnDoesNotExist when no column matches.
ResultOr<Value> ColumnNameExpression::evaluate(ExecutionContext& context) const
{
    if (!context.current_row)
        return Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, column_name() };

    auto& row = *context.current_row;
    auto& descriptor = *row.descriptor();
    VERIFY(row.size() == descriptor.size());

    Optional<size_t> match;
    for (size_t column = 0; column < row.size(); ++column) {
        auto& candidate = descriptor[column];
        // An unqualified reference (empty table name) accepts columns from any table.
        bool const in_requested_table = table_name().is_empty() || candidate.table == table_name();
        if (!in_requested_table || candidate.name != column_name())
            continue;
        if (match.has_value())
            return Result { SQLCommand::Unknown, SQLErrorCode::AmbiguousColumnName, column_name() };
        match = column;
    }

    if (!match.has_value())
        return Result { SQLCommand::Unknown, SQLErrorCode::ColumnDoesNotExist, column_name() };
    return row[match.value()];
}
ResultOr<Value> MatchExpression::evaluate(ExecutionContext& context) const
{
switch (type()) {
case MatchOperator::Like: {
Value lhs_value = TRY(lhs()->evaluate(context));
Value rhs_value = TRY(rhs()->evaluate(context));
char escape_char = '\0';
if (escape()) {
auto escape_str = TRY(escape()->evaluate(context)).to_string();
if (escape_str.length() != 1)
return Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, "ESCAPE should be a single character" };
escape_char = escape_str[0];
}
// Compile the pattern into a simple regex.
// https://sqlite.org/lang_expr.html#the_like_glob_regexp_and_match_operators
bool escaped = false;
AK::StringBuilder builder;
builder.append('^');
for (auto c : rhs_value.to_string()) {
if (escape() && c == escape_char && !escaped) {
escaped = true;
} else if (s_posix_basic_metacharacters.contains(c)) {
escaped = false;
builder.append('\\');
builder.append(c);
} else if (c == '_' && !escaped) {
builder.append('.');
} else if (c == '%' && !escaped) {
builder.append(".*");
} else {
escaped = false;
builder.append(c);
}
}
builder.append('$');
// FIXME: We should probably cache this regex.
auto regex = Regex<PosixBasic>(builder.build());
auto result = regex.match(lhs_value.to_string(), PosixFlags::Insensitive | PosixFlags::Unicode);
return Value(invert_expression() ? !result.success : result.success);
}
case MatchOperator::Regexp: {
Value lhs_value = TRY(lhs()->evaluate(context));
Value rhs_value = TRY(rhs()->evaluate(context));
auto regex = Regex<PosixExtended>(rhs_value.to_string());
auto err = regex.parser_result.error;
if (err != regex::Error::NoError) {
StringBuilder builder;
builder.append("Regular expression: ");
builder.append(get_error_string(err));
return Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, builder.build() };
}
auto result = regex.match(lhs_value.to_string(), PosixFlags::Insensitive | PosixFlags::Unicode);
return Value(invert_expression() ? !result.success : result.success);
}
case MatchOperator::Glob:
case MatchOperator::Match:
default:
VERIFY_NOT_REACHED();
}
return Value::null();
}
}