Skip to content

Commit f842411

Browse files
authored
feat: implement reference visitor (#491)
1 parent a29355d commit f842411

3 files changed

Lines changed: 291 additions & 6 deletions

File tree

src/iceberg/expression/binder.cc

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
#include "iceberg/expression/binder.h"
2121

22+
#include "iceberg/result.h"
23+
#include "iceberg/util/macros.h"
24+
2225
namespace iceberg {
2326

2427
Binder::Binder(const Schema& schema, bool case_sensitive)
@@ -54,30 +57,30 @@ Result<std::shared_ptr<Expression>> Binder::Or(
5457

5558
// Binds an unbound predicate by delegating to pred->Bind() with this binder's
// schema and case-sensitivity setting.
// This diff swaps the null check from ICEBERG_DCHECK to ICEBERG_PRECHECK
// (presumably so a null argument yields an error Result instead of a debug-only
// assertion -- confirm against iceberg/util/macros.h).
Result<std::shared_ptr<Expression>> Binder::Predicate(
5659
const std::shared_ptr<UnboundPredicate>& pred) {
57-
ICEBERG_DCHECK(pred != nullptr, "Predicate cannot be null");
60+
ICEBERG_PRECHECK(pred != nullptr, "Predicate cannot be null");
5861
return pred->Bind(schema_, case_sensitive_);
5962
}
6063

6164
// Rejects a predicate that is already bound: after the null check, always
// returns an InvalidExpression error whose message embeds pred->ToString().
// This diff swaps the null check from ICEBERG_DCHECK to ICEBERG_PRECHECK.
Result<std::shared_ptr<Expression>> Binder::Predicate(
6265
const std::shared_ptr<BoundPredicate>& pred) {
63-
ICEBERG_DCHECK(pred != nullptr, "Predicate cannot be null");
66+
ICEBERG_PRECHECK(pred != nullptr, "Predicate cannot be null");
6467
return InvalidExpression("Found already bound predicate: {}", pred->ToString());
6568
}
6669

6770
// Rejects an aggregate that is already bound: after the null check, always
// returns an InvalidExpression error whose message embeds aggregate->ToString().
// This diff swaps the null check from ICEBERG_DCHECK to ICEBERG_PRECHECK.
Result<std::shared_ptr<Expression>> Binder::Aggregate(
6871
const std::shared_ptr<BoundAggregate>& aggregate) {
69-
ICEBERG_DCHECK(aggregate != nullptr, "Aggregate cannot be null");
72+
ICEBERG_PRECHECK(aggregate != nullptr, "Aggregate cannot be null");
7073
return InvalidExpression("Found already bound aggregate: {}", aggregate->ToString());
7174
}
7275

7376
// Binds an unbound aggregate by delegating to aggregate->Bind() with this
// binder's schema and case-sensitivity setting.
// This diff swaps the null check from ICEBERG_DCHECK to ICEBERG_PRECHECK.
Result<std::shared_ptr<Expression>> Binder::Aggregate(
7477
const std::shared_ptr<UnboundAggregate>& aggregate) {
75-
ICEBERG_DCHECK(aggregate != nullptr, "Aggregate cannot be null");
78+
ICEBERG_PRECHECK(aggregate != nullptr, "Aggregate cannot be null");
7679
return aggregate->Bind(schema_, case_sensitive_);
7780
}
7881

7982
// Static entry point: after the null check, constructs a fresh IsBoundVisitor
// and returns the result of running Visit<bool, IsBoundVisitor> over `expr`.
// This diff swaps the null check from ICEBERG_DCHECK to ICEBERG_PRECHECK.
Result<bool> IsBoundVisitor::IsBound(const std::shared_ptr<Expression>& expr) {
80-
ICEBERG_DCHECK(expr != nullptr, "Expression cannot be null");
83+
ICEBERG_PRECHECK(expr != nullptr, "Expression cannot be null");
8184
IsBoundVisitor visitor;
8285
return Visit<bool, IsBoundVisitor>(expr, visitor);
8386
}
@@ -113,4 +116,54 @@ Result<bool> IsBoundVisitor::Aggregate(
113116
return false;
114117
}
115118

119+
// Static entry point: after the null check, constructs a fresh ReferenceVisitor
// and runs Visit<FieldIdsSetRef, ReferenceVisitor> over `expr`.
// NOTE(review): the visit yields a Result<FieldIdsSetRef> that refers to the
// local visitor's set, while the declared return type holds the set by value.
// Presumably the Result conversion copies the set before `visitor` is
// destroyed -- confirm against the Result type in iceberg/result.h.
Result<std::unordered_set<int32_t>> ReferenceVisitor::GetReferencedFieldIds(
120+
const std::shared_ptr<Expression>& expr) {
121+
ICEBERG_PRECHECK(expr != nullptr, "Expression cannot be null");
122+
ReferenceVisitor visitor;
123+
return Visit<FieldIdsSetRef, ReferenceVisitor>(expr, visitor);
124+
}
125+
126+
// Inserts nothing; returns a reference to the visitor's accumulated set as-is.
Result<FieldIdsSetRef> ReferenceVisitor::AlwaysTrue() { return referenced_field_ids_; }
127+
128+
// Inserts nothing; returns a reference to the visitor's accumulated set as-is.
Result<FieldIdsSetRef> ReferenceVisitor::AlwaysFalse() { return referenced_field_ids_; }
129+
130+
// Ignores the child result and returns a reference to the visitor's own set.
// Every method of this visitor returns the same member set, so no merging is
// done here.
Result<FieldIdsSetRef> ReferenceVisitor::Not(
131+
[[maybe_unused]] const FieldIdsSetRef& child_result) {
132+
return referenced_field_ids_;
133+
}
134+
135+
// Ignores both child results and returns a reference to the visitor's own set.
// Every method of this visitor returns the same member set, so no merging is
// done here.
Result<FieldIdsSetRef> ReferenceVisitor::And(
136+
[[maybe_unused]] const FieldIdsSetRef& left_result,
137+
[[maybe_unused]] const FieldIdsSetRef& right_result) {
138+
return referenced_field_ids_;
139+
}
140+
141+
// Ignores both child results and returns a reference to the visitor's own set.
// Every method of this visitor returns the same member set, so no merging is
// done here.
Result<FieldIdsSetRef> ReferenceVisitor::Or(
142+
[[maybe_unused]] const FieldIdsSetRef& left_result,
143+
[[maybe_unused]] const FieldIdsSetRef& right_result) {
144+
return referenced_field_ids_;
145+
}
146+
147+
// Records the field ID of the bound predicate's reference in the visitor's set
// and returns a reference to that set.
// NOTE(review): `pred` and `pred->reference()` are dereferenced without a null
// check here -- presumably callers guarantee both are non-null; confirm.
Result<FieldIdsSetRef> ReferenceVisitor::Predicate(
148+
const std::shared_ptr<BoundPredicate>& pred) {
149+
referenced_field_ids_.insert(pred->reference()->field_id());
150+
return referenced_field_ids_;
151+
}
152+
153+
// Unbound predicates are rejected: always returns an InvalidExpression error
// and never inspects `pred`.
Result<FieldIdsSetRef> ReferenceVisitor::Predicate(
154+
[[maybe_unused]] const std::shared_ptr<UnboundPredicate>& pred) {
155+
return InvalidExpression("Cannot get referenced field IDs from unbound predicate");
156+
}
157+
158+
// Records the field ID of the bound aggregate's reference in the visitor's set
// and returns a reference to that set.
// NOTE(review): `aggregate` and `aggregate->reference()` are dereferenced
// without a null check here -- presumably callers guarantee both are non-null;
// confirm.
Result<FieldIdsSetRef> ReferenceVisitor::Aggregate(
159+
const std::shared_ptr<BoundAggregate>& aggregate) {
160+
referenced_field_ids_.insert(aggregate->reference()->field_id());
161+
return referenced_field_ids_;
162+
}
163+
164+
// Unbound aggregates are rejected: always returns an InvalidExpression error
// and never inspects `aggregate`.
Result<FieldIdsSetRef> ReferenceVisitor::Aggregate(
165+
[[maybe_unused]] const std::shared_ptr<UnboundAggregate>& aggregate) {
166+
return InvalidExpression("Cannot get referenced field IDs from unbound aggregate");
167+
}
168+
116169
} // namespace iceberg

src/iceberg/expression/binder.h

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
/// \file iceberg/expression/binder.h
2323
/// Bind an expression to a schema.
2424

25+
#include <functional>
26+
#include <unordered_set>
27+
2528
#include "iceberg/expression/expression_visitor.h"
2629

2730
namespace iceberg {
@@ -73,6 +76,31 @@ class ICEBERG_EXPORT IsBoundVisitor : public ExpressionVisitor<bool> {
7376
Result<bool> Aggregate(const std::shared_ptr<UnboundAggregate>& aggregate) override;
7477
};
7578

76-
// TODO(gangwu): add the Java parity `ReferenceVisitor`
79+
/// \brief Copyable reference to a set of field IDs, used as the visitor's
/// per-node result type.
using FieldIdsSetRef = std::reference_wrapper<std::unordered_set<int32_t>>;
80+
81+
/// \brief Visitor to collect referenced field IDs from an expression.
///
/// Field IDs are accumulated in a single member set. Bound predicates and bound
/// aggregates insert the field ID of their reference; constants and the logical
/// operators insert nothing. Unbound predicates and unbound aggregates produce
/// an InvalidExpression error.
82+
class ICEBERG_EXPORT ReferenceVisitor : public ExpressionVisitor<FieldIdsSetRef> {
83+
public:
84+
/// \brief Collect the field IDs referenced by an expression.
/// \param expr The expression to inspect; checked against null.
/// \return The set of referenced field IDs, or an error from the visit.
static Result<std::unordered_set<int32_t>> GetReferencedFieldIds(
85+
const std::shared_ptr<Expression>& expr);
86+
87+
Result<FieldIdsSetRef> AlwaysTrue() override;
88+
Result<FieldIdsSetRef> AlwaysFalse() override;
89+
Result<FieldIdsSetRef> Not(const FieldIdsSetRef& child_result) override;
90+
Result<FieldIdsSetRef> And(const FieldIdsSetRef& left_result,
91+
const FieldIdsSetRef& right_result) override;
92+
Result<FieldIdsSetRef> Or(const FieldIdsSetRef& left_result,
93+
const FieldIdsSetRef& right_result) override;
94+
Result<FieldIdsSetRef> Predicate(const std::shared_ptr<BoundPredicate>& pred) override;
95+
Result<FieldIdsSetRef> Predicate(
96+
const std::shared_ptr<UnboundPredicate>& pred) override;
97+
Result<FieldIdsSetRef> Aggregate(
98+
const std::shared_ptr<BoundAggregate>& aggregate) override;
99+
Result<FieldIdsSetRef> Aggregate(
100+
const std::shared_ptr<UnboundAggregate>& aggregate) override;
101+
102+
private:
103+
/// Accumulator that every visit method returns a reference to.
std::unordered_set<int32_t> referenced_field_ids_;
104+
};
77105

78106
} // namespace iceberg

src/iceberg/test/expression_visitor_test.cc

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "iceberg/expression/binder.h"
2323
#include "iceberg/expression/expressions.h"
2424
#include "iceberg/expression/rewrite_not.h"
25+
#include "iceberg/result.h"
2526
#include "iceberg/schema.h"
2627
#include "iceberg/test/matchers.h"
2728
#include "iceberg/type.h"
@@ -505,4 +506,207 @@ TEST_F(RewriteNotTest, ComplexExpression) {
505506
EXPECT_EQ(rewritten->op(), Expression::Operation::kOr);
506507
}
507508

509+
// Fixture for ReferenceVisitor tests; adds nothing on top of ExpressionVisitorTest.
class ReferenceVisitorTest : public ExpressionVisitorTest {};
510+
511+
// AlwaysTrue and AlwaysFalse must each yield an empty set of field IDs.
TEST_F(ReferenceVisitorTest, Constants) {
512+
// Constants should have no referenced fields
513+
auto true_expr = Expressions::AlwaysTrue();
514+
ICEBERG_UNWRAP_OR_FAIL(auto refs_true,
515+
ReferenceVisitor::GetReferencedFieldIds(true_expr));
516+
EXPECT_TRUE(refs_true.empty());
517+
518+
auto false_expr = Expressions::AlwaysFalse();
519+
ICEBERG_UNWRAP_OR_FAIL(auto refs_false,
520+
ReferenceVisitor::GetReferencedFieldIds(false_expr));
521+
EXPECT_TRUE(refs_false.empty());
522+
}
523+
524+
// An unbound predicate passed directly (without Bind) must fail with
// kInvalidExpression and the unbound-predicate error message.
TEST_F(ReferenceVisitorTest, UnboundPredicate) {
525+
auto unbound_pred = Expressions::Equal("name", Literal::String("Alice"));
526+
auto result = ReferenceVisitor::GetReferencedFieldIds(unbound_pred);
527+
EXPECT_THAT(result, IsError(ErrorKind::kInvalidExpression));
528+
EXPECT_THAT(result,
529+
HasErrorMessage("Cannot get referenced field IDs from unbound predicate"));
530+
}
531+
532+
// A single bound equality predicate on "name" must yield exactly field ID 2.
TEST_F(ReferenceVisitorTest, BoundPredicate) {
533+
// Bound predicate should return the field ID
534+
auto unbound_pred = Expressions::Equal("name", Literal::String("Alice"));
535+
ICEBERG_UNWRAP_OR_FAIL(auto bound_pred, Bind(unbound_pred));
536+
537+
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_pred));
538+
EXPECT_EQ(refs.size(), 1);
539+
EXPECT_EQ(refs.count(2), 1); // name field has id=2
540+
}
541+
542+
// Two independent bound comparison predicates are visited separately; each
// call must report only its own field ("age" -> 3, "salary" -> 4).
TEST_F(ReferenceVisitorTest, MultiplePredicates) {
543+
// Test various predicates with different fields
544+
auto pred_age = Expressions::GreaterThan("age", Literal::Int(25));
545+
ICEBERG_UNWRAP_OR_FAIL(auto bound_age, Bind(pred_age));
546+
ICEBERG_UNWRAP_OR_FAIL(auto refs_age,
547+
ReferenceVisitor::GetReferencedFieldIds(bound_age));
548+
EXPECT_EQ(refs_age.size(), 1);
549+
EXPECT_EQ(refs_age.count(3), 1); // age field has id=3
550+
551+
auto pred_salary = Expressions::LessThan("salary", Literal::Double(50000.0));
552+
ICEBERG_UNWRAP_OR_FAIL(auto bound_salary, Bind(pred_salary));
553+
ICEBERG_UNWRAP_OR_FAIL(auto refs_salary,
554+
ReferenceVisitor::GetReferencedFieldIds(bound_salary));
555+
EXPECT_EQ(refs_salary.size(), 1);
556+
EXPECT_EQ(refs_salary.count(4), 1); // salary field has id=4
557+
}
558+
559+
// Bound IsNull("name") and IsNaN("salary") must each report exactly the one
// field they reference (IDs 2 and 4 respectively).
TEST_F(ReferenceVisitorTest, UnaryPredicates) {
560+
// Test unary predicates
561+
auto pred_is_null = Expressions::IsNull("name");
562+
ICEBERG_UNWRAP_OR_FAIL(auto bound_is_null, Bind(pred_is_null));
563+
ICEBERG_UNWRAP_OR_FAIL(auto refs,
564+
ReferenceVisitor::GetReferencedFieldIds(bound_is_null));
565+
EXPECT_EQ(refs.size(), 1);
566+
EXPECT_EQ(refs.count(2), 1);
567+
568+
auto pred_is_nan = Expressions::IsNaN("salary");
569+
ICEBERG_UNWRAP_OR_FAIL(auto bound_is_nan, Bind(pred_is_nan));
570+
ICEBERG_UNWRAP_OR_FAIL(auto refs_nan,
571+
ReferenceVisitor::GetReferencedFieldIds(bound_is_nan));
572+
EXPECT_EQ(refs_nan.size(), 1);
573+
EXPECT_EQ(refs_nan.count(4), 1);
574+
}
575+
576+
// A bound AND of predicates on "name" and "age" must report both field IDs.
TEST_F(ReferenceVisitorTest, AndExpression) {
577+
// AND expression should return union of field IDs from both sides
578+
auto pred1 = Expressions::Equal("name", Literal::String("Alice"));
579+
auto pred2 = Expressions::GreaterThan("age", Literal::Int(25));
580+
auto and_expr = Expressions::And(pred1, pred2);
581+
582+
ICEBERG_UNWRAP_OR_FAIL(auto bound_and, Bind(and_expr));
583+
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_and));
584+
585+
EXPECT_EQ(refs.size(), 2);
586+
EXPECT_EQ(refs.count(2), 1); // name field
587+
EXPECT_EQ(refs.count(3), 1); // age field
588+
}
589+
590+
// A bound OR of predicates on "salary" and "active" must report both field IDs.
TEST_F(ReferenceVisitorTest, OrExpression) {
591+
// OR expression should return union of field IDs from both sides
592+
auto pred1 = Expressions::IsNull("salary");
593+
auto pred2 = Expressions::Equal("active", Literal::Boolean(true));
594+
auto or_expr = Expressions::Or(pred1, pred2);
595+
596+
ICEBERG_UNWRAP_OR_FAIL(auto bound_or, Bind(or_expr));
597+
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_or));
598+
599+
EXPECT_EQ(refs.size(), 2);
600+
EXPECT_EQ(refs.count(4), 1); // salary field
601+
EXPECT_EQ(refs.count(5), 1); // active field
602+
}
603+
604+
// A bound NOT wrapping a predicate on "name" must report that one field ID.
TEST_F(ReferenceVisitorTest, NotExpression) {
605+
// NOT expression should return field IDs from its child
606+
auto pred = Expressions::Equal("name", Literal::String("Alice"));
607+
auto not_expr = Expressions::Not(pred);
608+
609+
ICEBERG_UNWRAP_OR_FAIL(auto bound_not, Bind(not_expr));
610+
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_not));
611+
612+
EXPECT_EQ(refs.size(), 1);
613+
EXPECT_EQ(refs.count(2), 1); // name field
614+
}
615+
616+
// An OR of two ANDs over four distinct fields must report all four field IDs.
TEST_F(ReferenceVisitorTest, ComplexNestedExpression) {
617+
// (name = 'Alice' AND age > 25) OR (salary < 30000 AND active = true)
618+
// Should reference fields: name(2), age(3), salary(4), active(5)
619+
auto pred1 = Expressions::Equal("name", Literal::String("Alice"));
620+
auto pred2 = Expressions::GreaterThan("age", Literal::Int(25));
621+
auto pred3 = Expressions::LessThan("salary", Literal::Double(30000.0));
622+
auto pred4 = Expressions::Equal("active", Literal::Boolean(true));
623+
624+
auto and1 = Expressions::And(pred1, pred2);
625+
auto and2 = Expressions::And(pred3, pred4);
626+
auto complex_or = Expressions::Or(and1, and2);
627+
628+
ICEBERG_UNWRAP_OR_FAIL(auto bound_complex, Bind(complex_or));
629+
ICEBERG_UNWRAP_OR_FAIL(auto refs,
630+
ReferenceVisitor::GetReferencedFieldIds(bound_complex));
631+
632+
EXPECT_EQ(refs.size(), 4);
633+
EXPECT_EQ(refs.count(2), 1); // name field
634+
EXPECT_EQ(refs.count(3), 1); // age field
635+
EXPECT_EQ(refs.count(4), 1); // salary field
636+
EXPECT_EQ(refs.count(5), 1); // active field
637+
}
638+
639+
// Two predicates on the same field must collapse to a single field ID.
TEST_F(ReferenceVisitorTest, DuplicateFieldReferences) {
640+
// Multiple predicates referencing the same field
641+
// age > 25 AND age < 50
642+
auto pred1 = Expressions::GreaterThan("age", Literal::Int(25));
643+
auto pred2 = Expressions::LessThan("age", Literal::Int(50));
644+
auto and_expr = Expressions::And(pred1, pred2);
645+
646+
ICEBERG_UNWRAP_OR_FAIL(auto bound_and, Bind(and_expr));
647+
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_and));
648+
649+
// Should only contain the field ID once (set semantics)
650+
EXPECT_EQ(refs.size(), 1);
651+
EXPECT_EQ(refs.count(3), 1); // age field
652+
}
653+
654+
// Bound In and NotIn predicates must each report exactly the one field they
// reference, regardless of how many literals they carry.
TEST_F(ReferenceVisitorTest, SetPredicates) {
655+
// Test In predicate
656+
auto pred_in =
657+
Expressions::In("age", {Literal::Int(25), Literal::Int(30), Literal::Int(35)});
658+
ICEBERG_UNWRAP_OR_FAIL(auto bound_in, Bind(pred_in));
659+
ICEBERG_UNWRAP_OR_FAIL(auto refs_in, ReferenceVisitor::GetReferencedFieldIds(bound_in));
660+
661+
EXPECT_EQ(refs_in.size(), 1);
662+
EXPECT_EQ(refs_in.count(3), 1); // age field
663+
664+
// Test NotIn predicate
665+
auto pred_not_in =
666+
Expressions::NotIn("name", {Literal::String("Alice"), Literal::String("Bob")});
667+
ICEBERG_UNWRAP_OR_FAIL(auto bound_not_in, Bind(pred_not_in));
668+
ICEBERG_UNWRAP_OR_FAIL(auto refs_not_in,
669+
ReferenceVisitor::GetReferencedFieldIds(bound_not_in));
670+
671+
EXPECT_EQ(refs_not_in.size(), 1);
672+
EXPECT_EQ(refs_not_in.count(2), 1); // name field
673+
}
674+
675+
// An AND combining a bound predicate with an unbound one must fail with
// kInvalidExpression and the unbound-predicate error message.
// Note on naming: `bound_pred` holds the expression *before* binding; `pred1`
// is the bound result.
TEST_F(ReferenceVisitorTest, MixedBoundAndUnbound) {
676+
auto bound_pred = Expressions::Equal("name", Literal::String("Alice"));
677+
ICEBERG_UNWRAP_OR_FAIL(auto pred1, Bind(bound_pred));
678+
auto unbound_pred = Expressions::GreaterThan("age", Literal::Int(25));
679+
auto mixed_and = Expressions::And(pred1, unbound_pred);
680+
681+
auto result = ReferenceVisitor::GetReferencedFieldIds(mixed_and);
682+
EXPECT_THAT(result, IsError(ErrorKind::kInvalidExpression));
683+
EXPECT_THAT(result,
684+
HasErrorMessage("Cannot get referenced field IDs from unbound predicate"));
685+
}
686+
687+
// Builds an AND chain with one predicate per schema field and checks which
// field IDs remain after binding.
TEST_F(ReferenceVisitorTest, AllFields) {
688+
// Create expression with one predicate on each field in the schema
689+
auto pred1 = Expressions::NotNull("id");
690+
auto pred2 = Expressions::Equal("name", Literal::String("Test"));
691+
auto pred3 = Expressions::GreaterThan("age", Literal::Int(0));
692+
auto pred4 = Expressions::LessThan("salary", Literal::Double(100000.0));
693+
auto pred5 = Expressions::Equal("active", Literal::Boolean(true));
694+
695+
auto and1 = Expressions::And(pred1, pred2);
696+
auto and2 = Expressions::And(pred3, pred4);
697+
auto and3 = Expressions::And(and1, and2);
698+
auto all_fields = Expressions::And(and3, pred5);
699+
700+
ICEBERG_UNWRAP_OR_FAIL(auto bound_all, Bind(all_fields));
701+
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_all));
702+
703+
// Only 4 of the 5 fields are expected: the id predicate is optimized out
704+
EXPECT_EQ(refs.size(), 4);
705+
EXPECT_EQ(refs.count(1), 0); // id field is optimized out
706+
EXPECT_EQ(refs.count(2), 1); // name field
707+
EXPECT_EQ(refs.count(3), 1); // age field
708+
EXPECT_EQ(refs.count(4), 1); // salary field
709+
EXPECT_EQ(refs.count(5), 1); // active field
710+
}
711+
508712
} // namespace iceberg

0 commit comments

Comments
 (0)