Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 98 additions & 34 deletions src/paimon/common/predicate/like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,79 +16,143 @@

#include "paimon/common/predicate/like.h"

#include <string>
#include <vector>

namespace paimon {

namespace {

/// Maps the first byte of a UTF-8 sequence to the number of bytes that sequence announces.
///
/// @param leading_byte First byte of a (possibly malformed) UTF-8 sequence.
/// @return 2, 3 or 4 for the multi-byte lead ranges 0xC0-0xDF, 0xE0-0xEF and 0xF0-0xF7;
///         1 for ASCII and for every byte that cannot start a sequence
///         (continuation bytes 0x80-0xBF and 0xF8-0xFF), so callers always advance.
inline size_t Utf8CodePointLength(unsigned char leading_byte) {
    if (leading_byte >= 0xF0) {
        // Only 0xF0-0xF7 carry the 11110xxx prefix; anything above is not a lead byte.
        return leading_byte <= 0xF7 ? 4 : 1;
    }
    if (leading_byte >= 0xE0) {
        return 3;  // 1110xxxx
    }
    if (leading_byte >= 0xC0) {
        return 2;  // 110xxxxx
    }
    // ASCII (< 0x80) or a stray continuation byte (0x80-0xBF): consume a single byte.
    return 1;
}

/// Tells whether `code_point` is exactly one of the single-code-point line terminators
/// listed below: LF, CR, U+0085 (NEL), U+2028 (LINE SEPARATOR) or U+2029 (PARAGRAPH
/// SEPARATOR), each given in its UTF-8 encoding.
///
/// @param code_point UTF-8 bytes of one code point; any other content yields false.
/// @return true when the bytes equal one of the listed terminators, false otherwise.
inline bool IsJavaRegexLineTerminator(const std::string& code_point) {
    constexpr const char* kLineTerminators[] = {
        "\n",            // U+000A
        "\r",            // U+000D
        "\xC2\x85",      // U+0085
        "\xE2\x80\xA8",  // U+2028
        "\xE2\x80\xA9",  // U+2029
    };
    for (const char* terminator : kLineTerminators) {
        if (code_point == terminator) {
            return true;
        }
    }
    return false;
}

} // namespace

Result<bool> Like::TestString(const std::string& field, const std::string& pattern) const {
if (pattern.empty()) {
return field.empty();
}
std::vector<char> pat;

// Phase 1: Parse pattern with escape handling (Java-compatible).
// Only \_, \%, \\ are valid escape sequences.
std::vector<std::string> pat_chars; // each element is a literal string segment or wildcard
std::vector<bool> is_wild;
for (size_t i = 0; i < pattern.size(); ++i) {
if (pattern[i] == '\\' && i + 1 < pattern.size()) {
pat.push_back(pattern[i + 1]);

for (size_t i = 0; i < pattern.size();) {
if (pattern[i] == '\\') {
if (i + 1 >= pattern.size()) {
return Status::Invalid("Invalid escape sequence '" + pattern + "', " +
std::to_string(i));
}
char next_char = pattern[i + 1];
if (next_char != '_' && next_char != '%' && next_char != '\\') {
return Status::Invalid("Invalid escape sequence '" + pattern + "', " +
std::to_string(i));
}
pat_chars.emplace_back(std::string(1, next_char));
is_wild.push_back(false);
i += 2;
} else if (pattern[i] == '_' || pattern[i] == '%') {
pat_chars.emplace_back(std::string(1, pattern[i]));
is_wild.push_back(true);
++i;
} else {
char c = pattern[i];
pat.push_back(c);
is_wild.push_back(c == '_' || c == '%');
// Read one UTF-8 code point from pattern as a literal element.
size_t cp_len = Utf8CodePointLength(static_cast<unsigned char>(pattern[i]));
if (i + cp_len > pattern.size()) {
cp_len = 1;
}
pat_chars.push_back(pattern.substr(i, cp_len));
is_wild.push_back(false);
i += cp_len;
}
}
std::vector<char> simp_pat;

// Phase 2: Merge consecutive '%' wildcards.
std::vector<std::string> simp_pat;
std::vector<bool> simp_wild;
for (size_t i = 0; i < pat.size(); ++i) {
if (is_wild[i] && pat[i] == '%' && !simp_pat.empty() && simp_wild.back() &&
simp_pat.back() == '%') {
for (size_t i = 0; i < pat_chars.size(); ++i) {
if (is_wild[i] && pat_chars[i] == "%" && !simp_pat.empty() && simp_wild.back() &&
simp_pat.back() == "%") {
continue;
}
simp_pat.push_back(pat[i]);
simp_pat.push_back(pat_chars[i]);
simp_wild.push_back(is_wild[i]);
}
const size_t m = field.size();

// Phase 3: Decompose field into UTF-8 code points for character-level matching.
std::vector<std::string> field_chars;
for (size_t i = 0; i < field.size();) {
size_t cp_len = Utf8CodePointLength(static_cast<unsigned char>(field[i]));
if (i + cp_len > field.size()) {
cp_len = 1; // truncated sequence, treat byte as single char
}
field_chars.push_back(field.substr(i, cp_len));
i += cp_len;
}

const size_t m = field_chars.size();
const size_t n = simp_pat.size();
if (field.empty()) {
return n == 1 && simp_wild[0] && simp_pat[0] == '%';

if (m == 0) {
return n == 1 && simp_wild[0] && simp_pat[0] == "%";
}

// Quick reject: count minimum required characters (non-wildcard pattern elements).
size_t min_len = 0;
for (size_t i = 0; i < n; ++i) {
if (!simp_wild[i]) {
min_len++;
} else if (simp_pat[i] == "_") {
min_len++;
}
}
if (min_len > m) {
return false;
}
constexpr size_t STACK_LIMIT = 128;
std::unique_ptr<bool[]> dp_storage;
bool* dp;
if (n <= STACK_LIMIT) {
dp = static_cast<bool*>(alloca((n + 1) * sizeof(bool)));
} else {
dp_storage = std::make_unique<bool[]>(n + 1);
dp = dp_storage.get();
}
std::fill_n(dp, n + 1, false);

// Phase 4: DP matching at character (code point) level.
std::vector<bool> dp(n + 1, false);
dp[0] = true;
for (size_t j = 1; j <= n && simp_wild[j - 1] && simp_pat[j - 1] == '%'; ++j) {
for (size_t j = 1; j <= n && simp_wild[j - 1] && simp_pat[j - 1] == "%"; ++j) {
dp[j] = true;
}
const char* f = field.data();

for (size_t i = 0; i < m; ++i) {
const char sc = f[i];
const std::string& field_char = field_chars[i];
bool prev = dp[0];
dp[0] = false;
bool has_match = false;
for (size_t j = 1; j <= n; ++j) {
const bool temp = dp[j];
const char pc = simp_pat[j - 1];
const std::string& pc = simp_pat[j - 1];
const bool wild = simp_wild[j - 1];
if (wild && pc == '%') {
if (wild && pc == "%") {
dp[j] = dp[j - 1] || dp[j];
} else if (wild && pc == '_') {
dp[j] = prev;
} else if (wild && pc == "_") {
dp[j] = prev && !IsJavaRegexLineTerminator(field_char);
} else {
dp[j] = (pc == sc) ? prev : false;
dp[j] = (pc == field_char) ? prev : false;
}
has_match |= dp[j];
prev = temp;
Expand All @@ -97,6 +161,6 @@ Result<bool> Like::TestString(const std::string& field, const std::string& patte
return false;
}
}
return dp[n];
return static_cast<bool>(dp[n]);
}
} // namespace paimon
119 changes: 119 additions & 0 deletions src/paimon/common/predicate/predicate_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1519,6 +1519,125 @@ TEST_F(PredicateTest, TestLikeLongPatternHeapAlloc) {
ASSERT_FALSE(predicate->Test(arrow_schema, CreateStringRow({non_matching_field})).value());
}

// Checks that LIKE patterns containing an illegal backslash escape are rejected when the
// predicate is evaluated: each case builds the predicate successfully and then expects
// Test() to yield a non-OK result whose message mentions "Invalid escape sequence".
TEST_F(PredicateTest, TestLikeInvalidEscapeSequence) {
// Single nullable-by-default utf8 column "f0" that every row below is evaluated against.
auto arrow_schema = arrow::schema(arrow::FieldVector({arrow::field("f0", arrow::utf8())}));

// Trailing backslash is invalid (Java throws "Invalid escape sequence")
// The pattern bytes are: a, b, c, backslash (4 bytes, matching the length argument).
ASSERT_OK_AND_ASSIGN(auto predicate_base,
PredicateBuilder::Like(
/*field_index=*/0, /*field_name=*/"f0", FieldType::STRING,
Literal(FieldType::STRING, "abc\\", 4)));
auto predicate = std::dynamic_pointer_cast<PredicateFilter>(predicate_base);
ASSERT_NOK_WITH_MSG(predicate->Test(arrow_schema, CreateStringRow({"abc"})),
"Invalid escape sequence");

// Backslash followed by non-special char is invalid (only \_, \%, \\ are legal)
// The pattern bytes are: a, backslash, b, c.
ASSERT_OK_AND_ASSIGN(predicate_base,
PredicateBuilder::Like(
/*field_index=*/0, /*field_name=*/"f0", FieldType::STRING,
Literal(FieldType::STRING, "a\\bc", 4)));
predicate = std::dynamic_pointer_cast<PredicateFilter>(predicate_base);
ASSERT_NOK_WITH_MSG(predicate->Test(arrow_schema, CreateStringRow({"abc"})),
"Invalid escape sequence");

// \n is not a valid escape
// Here the pattern holds a literal backslash followed by the letter 'n', not a newline byte.
ASSERT_OK_AND_ASSIGN(predicate_base,
PredicateBuilder::Like(
/*field_index=*/0, /*field_name=*/"f0", FieldType::STRING,
Literal(FieldType::STRING, "a\\nf", 4)));
predicate = std::dynamic_pointer_cast<PredicateFilter>(predicate_base);
ASSERT_NOK_WITH_MSG(predicate->Test(arrow_schema, CreateStringRow({"anf"})),
"Invalid escape sequence");
}

// Checks the legal escapes: an escaped backslash and an escaped percent sign in the pattern
// must match only the corresponding literal character in the field, never act as a wildcard.
TEST_F(PredicateTest, TestLikeEscapeBackslash) {
auto arrow_schema = arrow::schema(arrow::FieldVector({arrow::field("f0", arrow::utf8())}));

// \\\\ in C++ string literal = "\\" in the pattern = escaped backslash
// Resulting pattern bytes: a, backslash, backslash, b (4 bytes).
ASSERT_OK_AND_ASSIGN(auto predicate_base,
PredicateBuilder::Like(
/*field_index=*/0, /*field_name=*/"f0", FieldType::STRING,
Literal(FieldType::STRING, "a\\\\b", 4)));
auto predicate = std::dynamic_pointer_cast<PredicateFilter>(predicate_base);
// Field "a\b" should match pattern "a\\b" (escaped backslash)
ASSERT_TRUE(predicate->Test(arrow_schema, CreateStringRow({"a\\b"})).value());
// Any other middle character must not match the escaped backslash.
ASSERT_FALSE(predicate->Test(arrow_schema, CreateStringRow({"axb"})).value());

// Escaped percent: "a\%b" matches literal "a%b"
ASSERT_OK_AND_ASSIGN(predicate_base,
PredicateBuilder::Like(
/*field_index=*/0, /*field_name=*/"f0", FieldType::STRING,
Literal(FieldType::STRING, "a\\%b", 4)));
predicate = std::dynamic_pointer_cast<PredicateFilter>(predicate_base);
ASSERT_TRUE(predicate->Test(arrow_schema, CreateStringRow({"a%b"})).value());
// The escaped '%' must not behave as a wildcard for one or several characters.
ASSERT_FALSE(predicate->Test(arrow_schema, CreateStringRow({"axb"})).value());
ASSERT_FALSE(predicate->Test(arrow_schema, CreateStringRow({"axxb"})).value());
}

// Checks that LIKE wildcards operate on whole UTF-8 characters: '_' consumes exactly one
// multi-byte character (not one byte), and '%' still matches around multi-byte literals.
TEST_F(PredicateTest, TestLikeUtf8MultibyteUnderscore) {
auto arrow_schema = arrow::schema(arrow::FieldVector({arrow::field("f0", arrow::utf8())}));

// Single '_' should match one Unicode character, not one byte.
ASSERT_OK_AND_ASSIGN(auto predicate_base,
PredicateBuilder::Like(
/*field_index=*/0, /*field_name=*/"f0", FieldType::STRING,
Literal(FieldType::STRING, "_", 1)));
auto predicate = std::dynamic_pointer_cast<PredicateFilter>(predicate_base);
// One multi-byte character matches; two characters are one too many for a single '_'.
ASSERT_TRUE(predicate->Test(arrow_schema, CreateStringRow({"中"})).value());
ASSERT_FALSE(predicate->Test(arrow_schema, CreateStringRow({"中文"})).value());

// "a_c" where _ matches one Chinese character
ASSERT_OK_AND_ASSIGN(predicate_base,
PredicateBuilder::Like(
/*field_index=*/0, /*field_name=*/"f0", FieldType::STRING,
Literal(FieldType::STRING, "a_c", 3)));
predicate = std::dynamic_pointer_cast<PredicateFilter>(predicate_base);
ASSERT_TRUE(predicate->Test(arrow_schema, CreateStringRow({"a中c"})).value());
ASSERT_FALSE(predicate->Test(arrow_schema, CreateStringRow({"a中文c"})).value());

// "___" should match exactly 3 Unicode characters
ASSERT_OK_AND_ASSIGN(predicate_base,
PredicateBuilder::Like(
/*field_index=*/0, /*field_name=*/"f0", FieldType::STRING,
Literal(FieldType::STRING, "___", 3)));
predicate = std::dynamic_pointer_cast<PredicateFilter>(predicate_base);
ASSERT_TRUE(predicate->Test(arrow_schema, CreateStringRow({"中文字"})).value());
ASSERT_FALSE(predicate->Test(arrow_schema, CreateStringRow({"中文"})).value());

// '%' should still work with multi-byte characters
// The pattern is assembled at runtime so its byte length comes from size() rather than
// a hand-counted constant.
std::string pattern_contains = std::string("%") + "中" + "%";
ASSERT_OK_AND_ASSIGN(
predicate_base,
PredicateBuilder::Like(
/*field_index=*/0, /*field_name=*/"f0", FieldType::STRING,
Literal(FieldType::STRING, pattern_contains.data(), pattern_contains.size())));
predicate = std::dynamic_pointer_cast<PredicateFilter>(predicate_base);
ASSERT_TRUE(predicate->Test(arrow_schema, CreateStringRow({"hello中world"})).value());
ASSERT_FALSE(predicate->Test(arrow_schema, CreateStringRow({"helloworld"})).value());
}

// Checks line-terminator handling of the two wildcards: '_' must reject a field consisting
// of a single "\n" or "\r", while '%' must accept the same fields.
TEST_F(PredicateTest, TestLikeJavaRegexLineTerminatorSemantics) {
auto arrow_schema = arrow::schema(arrow::FieldVector({arrow::field("f0", arrow::utf8())}));

// Java regex '.' does not match line terminators, so '_' should not match them either.
ASSERT_OK_AND_ASSIGN(auto predicate_base,
PredicateBuilder::Like(
/*field_index=*/0, /*field_name=*/"f0", FieldType::STRING,
Literal(FieldType::STRING, "_", 1)));
auto predicate = std::dynamic_pointer_cast<PredicateFilter>(predicate_base);
ASSERT_FALSE(predicate->Test(arrow_schema, CreateStringRow({"\n"})).value());
ASSERT_FALSE(predicate->Test(arrow_schema, CreateStringRow({"\r"})).value());

// Java LIKE '%' uses (?s:.*), so it should still match line terminators.
ASSERT_OK_AND_ASSIGN(predicate_base,
PredicateBuilder::Like(
/*field_index=*/0, /*field_name=*/"f0", FieldType::STRING,
Literal(FieldType::STRING, "%", 1)));
predicate = std::dynamic_pointer_cast<PredicateFilter>(predicate_base);
ASSERT_TRUE(predicate->Test(arrow_schema, CreateStringRow({"\n"})).value());
ASSERT_TRUE(predicate->Test(arrow_schema, CreateStringRow({"\r"})).value());
}

TEST_F(PredicateTest, TestCompound) {
ASSERT_OK_AND_ASSIGN(
const auto startswith_predicate,
Expand Down
Loading