Skip to content

Commit f2674e8

Browse files
committed
feat: implement reference visitor
1 parent e5eb6e0 commit f2674e8

File tree

3 files changed

+292
-6
lines changed

3 files changed

+292
-6
lines changed

src/iceberg/expression/binder.cc

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
#include "iceberg/expression/binder.h"
2121

22+
#include "iceberg/util/macros.h"
23+
2224
namespace iceberg {
2325

2426
Binder::Binder(const Schema& schema, bool case_sensitive)
@@ -54,30 +56,30 @@ Result<std::shared_ptr<Expression>> Binder::Or(
5456

5557
Result<std::shared_ptr<Expression>> Binder::Predicate(
5658
const std::shared_ptr<UnboundPredicate>& pred) {
57-
ICEBERG_DCHECK(pred != nullptr, "Predicate cannot be null");
59+
ICEBERG_PRECHECK(pred != nullptr, "Predicate cannot be null");
5860
return pred->Bind(schema_, case_sensitive_);
5961
}
6062

6163
Result<std::shared_ptr<Expression>> Binder::Predicate(
6264
const std::shared_ptr<BoundPredicate>& pred) {
63-
ICEBERG_DCHECK(pred != nullptr, "Predicate cannot be null");
65+
ICEBERG_PRECHECK(pred != nullptr, "Predicate cannot be null");
6466
return InvalidExpression("Found already bound predicate: {}", pred->ToString());
6567
}
6668

6769
Result<std::shared_ptr<Expression>> Binder::Aggregate(
6870
const std::shared_ptr<BoundAggregate>& aggregate) {
69-
ICEBERG_DCHECK(aggregate != nullptr, "Aggregate cannot be null");
71+
ICEBERG_PRECHECK(aggregate != nullptr, "Aggregate cannot be null");
7072
return InvalidExpression("Found already bound aggregate: {}", aggregate->ToString());
7173
}
7274

7375
Result<std::shared_ptr<Expression>> Binder::Aggregate(
7476
const std::shared_ptr<UnboundAggregate>& aggregate) {
75-
ICEBERG_DCHECK(aggregate != nullptr, "Aggregate cannot be null");
77+
ICEBERG_PRECHECK(aggregate != nullptr, "Aggregate cannot be null");
7678
return aggregate->Bind(schema_, case_sensitive_);
7779
}
7880

7981
Result<bool> IsBoundVisitor::IsBound(const std::shared_ptr<Expression>& expr) {
80-
ICEBERG_DCHECK(expr != nullptr, "Expression cannot be null");
82+
ICEBERG_PRECHECK(expr != nullptr, "Expression cannot be null");
8183
IsBoundVisitor visitor;
8284
return Visit<bool, IsBoundVisitor>(expr, visitor);
8385
}
@@ -113,4 +115,54 @@ Result<bool> IsBoundVisitor::Aggregate(
113115
return false;
114116
}
115117

118+
Result<std::unordered_set<int32_t>> ReferenceVisitor::GetReferencedFieldIds(
119+
const std::shared_ptr<Expression>& expr) {
120+
ICEBERG_PRECHECK(expr != nullptr, "Expression cannot be null");
121+
ReferenceVisitor visitor;
122+
return Visit<FieldIdsSetRef, ReferenceVisitor>(expr, visitor);
123+
}
124+
125+
Result<FieldIdsSetRef> ReferenceVisitor::AlwaysTrue() { return referenced_field_ids_; }
126+
127+
Result<FieldIdsSetRef> ReferenceVisitor::AlwaysFalse() { return referenced_field_ids_; }
128+
129+
Result<FieldIdsSetRef> ReferenceVisitor::Not(
130+
[[maybe_unused]] const FieldIdsSetRef& child_result) {
131+
return referenced_field_ids_;
132+
}
133+
134+
Result<FieldIdsSetRef> ReferenceVisitor::And(
135+
[[maybe_unused]] const FieldIdsSetRef& left_result,
136+
[[maybe_unused]] const FieldIdsSetRef& right_result) {
137+
return referenced_field_ids_;
138+
}
139+
140+
Result<FieldIdsSetRef> ReferenceVisitor::Or(
141+
[[maybe_unused]] const FieldIdsSetRef& left_result,
142+
[[maybe_unused]] const FieldIdsSetRef& right_result) {
143+
return referenced_field_ids_;
144+
}
145+
146+
Result<FieldIdsSetRef> ReferenceVisitor::Predicate(
147+
const std::shared_ptr<BoundPredicate>& pred) {
148+
referenced_field_ids_.insert(pred->reference()->field_id());
149+
return referenced_field_ids_;
150+
}
151+
152+
Result<FieldIdsSetRef> ReferenceVisitor::Predicate(
153+
[[maybe_unused]] const std::shared_ptr<UnboundPredicate>& pred) {
154+
return referenced_field_ids_;
155+
}
156+
157+
Result<FieldIdsSetRef> ReferenceVisitor::Aggregate(
158+
const std::shared_ptr<BoundAggregate>& aggregate) {
159+
referenced_field_ids_.insert(aggregate->reference()->field_id());
160+
return referenced_field_ids_;
161+
}
162+
163+
Result<FieldIdsSetRef> ReferenceVisitor::Aggregate(
164+
[[maybe_unused]] const std::shared_ptr<UnboundAggregate>& aggregate) {
165+
return referenced_field_ids_;
166+
}
167+
116168
} // namespace iceberg

src/iceberg/expression/binder.h

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
/// \file iceberg/expression/binder.h
2323
/// Bind an expression to a schema.
2424

25+
#include <functional>
26+
#include <unordered_set>
27+
2528
#include "iceberg/expression/expression_visitor.h"
2629

2730
namespace iceberg {
@@ -73,6 +76,31 @@ class ICEBERG_EXPORT IsBoundVisitor : public ExpressionVisitor<bool> {
7376
Result<bool> Aggregate(const std::shared_ptr<UnboundAggregate>& aggregate) override;
7477
};
7578

76-
// TODO(gangwu): add the Java parity `ReferenceVisitor`
79+
using FieldIdsSetRef = std::reference_wrapper<std::unordered_set<int32_t>>;
80+
81+
/// \brief Visitor to collect the field IDs referenced by bound predicates and bound aggregates in an expression; unbound terms contribute no IDs.
82+
class ICEBERG_EXPORT ReferenceVisitor : public ExpressionVisitor<FieldIdsSetRef> {
83+
public:
84+
static Result<std::unordered_set<int32_t>> GetReferencedFieldIds(
85+
const std::shared_ptr<Expression>& expr);
86+
87+
Result<FieldIdsSetRef> AlwaysTrue() override;
88+
Result<FieldIdsSetRef> AlwaysFalse() override;
89+
Result<FieldIdsSetRef> Not(const FieldIdsSetRef& child_result) override;
90+
Result<FieldIdsSetRef> And(const FieldIdsSetRef& left_result,
91+
const FieldIdsSetRef& right_result) override;
92+
Result<FieldIdsSetRef> Or(const FieldIdsSetRef& left_result,
93+
const FieldIdsSetRef& right_result) override;
94+
Result<FieldIdsSetRef> Predicate(const std::shared_ptr<BoundPredicate>& pred) override;
95+
Result<FieldIdsSetRef> Predicate(
96+
const std::shared_ptr<UnboundPredicate>& pred) override;
97+
Result<FieldIdsSetRef> Aggregate(
98+
const std::shared_ptr<BoundAggregate>& aggregate) override;
99+
Result<FieldIdsSetRef> Aggregate(
100+
const std::shared_ptr<UnboundAggregate>& aggregate) override;
101+
102+
private:
103+
std::unordered_set<int32_t> referenced_field_ids_;
104+
};
77105

78106
} // namespace iceberg

src/iceberg/test/expression_visitor_test.cc

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,4 +505,210 @@ TEST_F(RewriteNotTest, ComplexExpression) {
505505
EXPECT_EQ(rewritten->op(), Expression::Operation::kOr);
506506
}
507507

508+
class ReferenceVisitorTest : public ExpressionVisitorTest {};
509+
510+
TEST_F(ReferenceVisitorTest, Constants) {
511+
// Constants should have no referenced fields
512+
auto true_expr = Expressions::AlwaysTrue();
513+
ICEBERG_UNWRAP_OR_FAIL(auto refs_true,
514+
ReferenceVisitor::GetReferencedFieldIds(true_expr));
515+
EXPECT_TRUE(refs_true.empty());
516+
517+
auto false_expr = Expressions::AlwaysFalse();
518+
ICEBERG_UNWRAP_OR_FAIL(auto refs_false,
519+
ReferenceVisitor::GetReferencedFieldIds(false_expr));
520+
EXPECT_TRUE(refs_false.empty());
521+
}
522+
523+
TEST_F(ReferenceVisitorTest, UnboundPredicate) {
524+
// Unbound predicates should have no referenced field IDs (not yet bound to schema)
525+
auto unbound_pred = Expressions::Equal("name", Literal::String("Alice"));
526+
ICEBERG_UNWRAP_OR_FAIL(auto refs,
527+
ReferenceVisitor::GetReferencedFieldIds(unbound_pred));
528+
EXPECT_TRUE(refs.empty());
529+
}
530+
531+
TEST_F(ReferenceVisitorTest, BoundPredicate) {
532+
// Bound predicate should return the field ID
533+
auto unbound_pred = Expressions::Equal("name", Literal::String("Alice"));
534+
ICEBERG_UNWRAP_OR_FAIL(auto bound_pred, Bind(unbound_pred));
535+
536+
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_pred));
537+
EXPECT_EQ(refs.size(), 1);
538+
EXPECT_EQ(refs.count(2), 1); // name field has id=2
539+
}
540+
541+
TEST_F(ReferenceVisitorTest, MultiplePredicates) {
542+
// Test various predicates with different fields
543+
auto pred_age = Expressions::GreaterThan("age", Literal::Int(25));
544+
ICEBERG_UNWRAP_OR_FAIL(auto bound_age, Bind(pred_age));
545+
ICEBERG_UNWRAP_OR_FAIL(auto refs_age,
546+
ReferenceVisitor::GetReferencedFieldIds(bound_age));
547+
EXPECT_EQ(refs_age.size(), 1);
548+
EXPECT_EQ(refs_age.count(3), 1); // age field has id=3
549+
550+
auto pred_salary = Expressions::LessThan("salary", Literal::Double(50000.0));
551+
ICEBERG_UNWRAP_OR_FAIL(auto bound_salary, Bind(pred_salary));
552+
ICEBERG_UNWRAP_OR_FAIL(auto refs_salary,
553+
ReferenceVisitor::GetReferencedFieldIds(bound_salary));
554+
EXPECT_EQ(refs_salary.size(), 1);
555+
EXPECT_EQ(refs_salary.count(4), 1); // salary field has id=4
556+
}
557+
558+
TEST_F(ReferenceVisitorTest, UnaryPredicates) {
559+
// Test unary predicates
560+
auto pred_is_null = Expressions::IsNull("name");
561+
ICEBERG_UNWRAP_OR_FAIL(auto bound_is_null, Bind(pred_is_null));
562+
ICEBERG_UNWRAP_OR_FAIL(auto refs,
563+
ReferenceVisitor::GetReferencedFieldIds(bound_is_null));
564+
EXPECT_EQ(refs.size(), 1);
565+
EXPECT_EQ(refs.count(2), 1);
566+
567+
auto pred_is_nan = Expressions::IsNaN("salary");
568+
ICEBERG_UNWRAP_OR_FAIL(auto bound_is_nan, Bind(pred_is_nan));
569+
ICEBERG_UNWRAP_OR_FAIL(auto refs_nan,
570+
ReferenceVisitor::GetReferencedFieldIds(bound_is_nan));
571+
EXPECT_EQ(refs_nan.size(), 1);
572+
EXPECT_EQ(refs_nan.count(4), 1);
573+
}
574+
575+
TEST_F(ReferenceVisitorTest, AndExpression) {
576+
// AND expression should return union of field IDs from both sides
577+
auto pred1 = Expressions::Equal("name", Literal::String("Alice"));
578+
auto pred2 = Expressions::GreaterThan("age", Literal::Int(25));
579+
auto and_expr = Expressions::And(pred1, pred2);
580+
581+
ICEBERG_UNWRAP_OR_FAIL(auto bound_and, Bind(and_expr));
582+
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_and));
583+
584+
EXPECT_EQ(refs.size(), 2);
585+
EXPECT_EQ(refs.count(2), 1); // name field
586+
EXPECT_EQ(refs.count(3), 1); // age field
587+
}
588+
589+
TEST_F(ReferenceVisitorTest, OrExpression) {
590+
// OR expression should return union of field IDs from both sides
591+
auto pred1 = Expressions::IsNull("salary");
592+
auto pred2 = Expressions::Equal("active", Literal::Boolean(true));
593+
auto or_expr = Expressions::Or(pred1, pred2);
594+
595+
ICEBERG_UNWRAP_OR_FAIL(auto bound_or, Bind(or_expr));
596+
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_or));
597+
598+
EXPECT_EQ(refs.size(), 2);
599+
EXPECT_EQ(refs.count(4), 1); // salary field
600+
EXPECT_EQ(refs.count(5), 1); // active field
601+
}
602+
603+
TEST_F(ReferenceVisitorTest, NotExpression) {
604+
// NOT expression should return field IDs from its child
605+
auto pred = Expressions::Equal("name", Literal::String("Alice"));
606+
auto not_expr = Expressions::Not(pred);
607+
608+
ICEBERG_UNWRAP_OR_FAIL(auto bound_not, Bind(not_expr));
609+
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_not));
610+
611+
EXPECT_EQ(refs.size(), 1);
612+
EXPECT_EQ(refs.count(2), 1); // name field
613+
}
614+
615+
TEST_F(ReferenceVisitorTest, ComplexNestedExpression) {
616+
// (name = 'Alice' AND age > 25) OR (salary < 30000 AND active = true)
617+
// Should reference fields: name(2), age(3), salary(4), active(5)
618+
auto pred1 = Expressions::Equal("name", Literal::String("Alice"));
619+
auto pred2 = Expressions::GreaterThan("age", Literal::Int(25));
620+
auto pred3 = Expressions::LessThan("salary", Literal::Double(30000.0));
621+
auto pred4 = Expressions::Equal("active", Literal::Boolean(true));
622+
623+
auto and1 = Expressions::And(pred1, pred2);
624+
auto and2 = Expressions::And(pred3, pred4);
625+
auto complex_or = Expressions::Or(and1, and2);
626+
627+
ICEBERG_UNWRAP_OR_FAIL(auto bound_complex, Bind(complex_or));
628+
ICEBERG_UNWRAP_OR_FAIL(auto refs,
629+
ReferenceVisitor::GetReferencedFieldIds(bound_complex));
630+
631+
EXPECT_EQ(refs.size(), 4);
632+
EXPECT_EQ(refs.count(2), 1); // name field
633+
EXPECT_EQ(refs.count(3), 1); // age field
634+
EXPECT_EQ(refs.count(4), 1); // salary field
635+
EXPECT_EQ(refs.count(5), 1); // active field
636+
}
637+
638+
TEST_F(ReferenceVisitorTest, DuplicateFieldReferences) {
639+
// Multiple predicates referencing the same field
640+
// age > 25 AND age < 50
641+
auto pred1 = Expressions::GreaterThan("age", Literal::Int(25));
642+
auto pred2 = Expressions::LessThan("age", Literal::Int(50));
643+
auto and_expr = Expressions::And(pred1, pred2);
644+
645+
ICEBERG_UNWRAP_OR_FAIL(auto bound_and, Bind(and_expr));
646+
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_and));
647+
648+
// Should only contain the field ID once (set semantics)
649+
EXPECT_EQ(refs.size(), 1);
650+
EXPECT_EQ(refs.count(3), 1); // age field
651+
}
652+
653+
TEST_F(ReferenceVisitorTest, SetPredicates) {
654+
// Test In predicate
655+
auto pred_in =
656+
Expressions::In("age", {Literal::Int(25), Literal::Int(30), Literal::Int(35)});
657+
ICEBERG_UNWRAP_OR_FAIL(auto bound_in, Bind(pred_in));
658+
ICEBERG_UNWRAP_OR_FAIL(auto refs_in, ReferenceVisitor::GetReferencedFieldIds(bound_in));
659+
660+
EXPECT_EQ(refs_in.size(), 1);
661+
EXPECT_EQ(refs_in.count(3), 1); // age field
662+
663+
// Test NotIn predicate
664+
auto pred_not_in =
665+
Expressions::NotIn("name", {Literal::String("Alice"), Literal::String("Bob")});
666+
ICEBERG_UNWRAP_OR_FAIL(auto bound_not_in, Bind(pred_not_in));
667+
ICEBERG_UNWRAP_OR_FAIL(auto refs_not_in,
668+
ReferenceVisitor::GetReferencedFieldIds(bound_not_in));
669+
670+
EXPECT_EQ(refs_not_in.size(), 1);
671+
EXPECT_EQ(refs_not_in.count(2), 1); // name field
672+
}
673+
674+
TEST_F(ReferenceVisitorTest, MixedBoundAndUnbound) {
675+
// Expression with both bound and unbound predicates
676+
auto bound_pred = Expressions::Equal("name", Literal::String("Alice"));
677+
ICEBERG_UNWRAP_OR_FAIL(auto pred1, Bind(bound_pred));
678+
679+
auto unbound_pred = Expressions::GreaterThan("age", Literal::Int(25));
680+
681+
auto mixed_and = Expressions::And(pred1, unbound_pred);
682+
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(mixed_and));
683+
684+
// Should only return field IDs from bound predicates
685+
EXPECT_EQ(refs.size(), 1);
686+
EXPECT_EQ(refs.count(2), 1); // name field only
687+
}
688+
689+
TEST_F(ReferenceVisitorTest, AllFields) {
690+
// Create expression referencing all fields in the schema
691+
auto pred1 = Expressions::NotNull("id");
692+
auto pred2 = Expressions::Equal("name", Literal::String("Test"));
693+
auto pred3 = Expressions::GreaterThan("age", Literal::Int(0));
694+
auto pred4 = Expressions::LessThan("salary", Literal::Double(100000.0));
695+
auto pred5 = Expressions::Equal("active", Literal::Boolean(true));
696+
697+
auto and1 = Expressions::And(pred1, pred2);
698+
auto and2 = Expressions::And(pred3, pred4);
699+
auto and3 = Expressions::And(and1, and2);
700+
auto all_fields = Expressions::And(and3, pred5);
701+
702+
ICEBERG_UNWRAP_OR_FAIL(auto bound_all, Bind(all_fields));
703+
ICEBERG_UNWRAP_OR_FAIL(auto refs, ReferenceVisitor::GetReferencedFieldIds(bound_all));
704+
705+
// Should reference only 4 of the 5 fields: the id field is optimized out
706+
EXPECT_EQ(refs.size(), 4);
707+
EXPECT_EQ(refs.count(1), 0); // id field is optimized out
708+
EXPECT_EQ(refs.count(2), 1); // name field
709+
EXPECT_EQ(refs.count(3), 1); // age field
710+
EXPECT_EQ(refs.count(4), 1); // salary field
711+
EXPECT_EQ(refs.count(5), 1); // active field
712+
}
713+
508714
} // namespace iceberg

0 commit comments

Comments
 (0)