Skip to content

Commit 1a4151d

Browse files
fixed count, upsert methods for vector
1 parent 76568b8 commit 1a4151d

File tree

3 files changed

+101
-6
lines changed

3 files changed

+101
-6
lines changed

src/Database/Adapter/SQL.php

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3188,7 +3188,18 @@ public function count(Document $collection, array $queries = [], ?int $max = nul
31883188

31893189
$queries = array_map(fn ($query) => clone $query, $queries);
31903190

3191-
$conditions = $this->getSQLConditions($queries, $binds);
3191+
// Extract vector queries (used for ORDER BY) and keep non-vector for WHERE
3192+
$vectorQueries = [];
3193+
$otherQueries = [];
3194+
foreach ($queries as $query) {
3195+
if (in_array($query->getMethod(), Query::VECTOR_TYPES)) {
3196+
$vectorQueries[] = $query;
3197+
} else {
3198+
$otherQueries[] = $query;
3199+
}
3200+
}
3201+
3202+
$conditions = $this->getSQLConditions($otherQueries, $binds);
31923203
if (!empty($conditions)) {
31933204
$where[] = $conditions;
31943205
}
@@ -3206,12 +3217,23 @@ public function count(Document $collection, array $queries = [], ?int $max = nul
32063217
? 'WHERE ' . \implode(' AND ', $where)
32073218
: '';
32083219

3220+
// Add vector distance calculations to ORDER BY (similarity-aware LIMIT)
3221+
$vectorOrders = [];
3222+
foreach ($vectorQueries as $query) {
3223+
$vectorOrder = $this->getVectorDistanceOrder($query, $binds, $alias);
3224+
if ($vectorOrder) {
3225+
$vectorOrders[] = $vectorOrder;
3226+
}
3227+
}
3228+
$sqlOrder = !empty($vectorOrders) ? 'ORDER BY ' . implode(', ', $vectorOrders) : '';
3229+
32093230
$sql = "
32103231
SELECT COUNT(1) as sum FROM (
32113232
SELECT 1
32123233
FROM {$this->getSQLTable($name)} AS {$this->quote($alias)}
3213-
{$sqlWhere}
3214-
{$limit}
3234+
{$sqlWhere}
3235+
{$sqlOrder}
3236+
{$limit}
32153237
) table_count
32163238
";
32173239

src/Database/Database.php

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -603,12 +603,12 @@ function (mixed $value) {
603603
return \json_encode(\array_map(\floatval(...), $value));
604604
},
605605
/**
606-
* @param string|null $value
606+
* @param mixed $value
607607
* @return array|null
608608
*/
609-
function (?string $value) {
609+
function (mixed $value) {
610610
if (is_null($value)) {
611-
return null;
611+
return;
612612
}
613613
if (!is_string($value)) {
614614
return $value;

tests/e2e/Adapter/Scopes/VectorTests.php

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2624,4 +2624,77 @@ public function testVectorQueryInNestedQuery(): void
26242624
// Cleanup
26252625
$database->deleteCollection('vectorNested');
26262626
}
2627+
2628+
public function testVectorQueryCount(): void
2629+
{
2630+
/** @var Database $database */
2631+
$database = static::getDatabase();
2632+
2633+
if (!$database->getAdapter()->getSupportForVectors()) {
2634+
$this->expectNotToPerformAssertions();
2635+
return;
2636+
}
2637+
2638+
$database->createCollection('vectorCount');
2639+
$database->createAttribute('vectorCount', 'embedding', Database::VAR_VECTOR, 3, true);
2640+
2641+
$database->createDocument('vectorCount', new Document([
2642+
'$permissions' => [
2643+
Permission::read(Role::any())
2644+
],
2645+
'embedding' => [1.0, 0.0, 0.0],
2646+
]));
2647+
2648+
$count = $database->count('vectorCount', [
2649+
Query::vectorCosine('embedding', [1.0, 0.0, 0.0]),
2650+
]);
2651+
2652+
$this->assertEquals(1, $count);
2653+
2654+
$database->deleteCollection('vectorCount');
2655+
}
2656+
2657+
public function testVetorUpsert(): void
2658+
{
2659+
/** @var Database $database */
2660+
$database = static::getDatabase();
2661+
2662+
if (!$database->getAdapter()->getSupportForVectors()) {
2663+
$this->expectNotToPerformAssertions();
2664+
return;
2665+
}
2666+
2667+
$database->createCollection('vectorUpsert');
2668+
$database->createAttribute('vectorUpsert', 'embedding', Database::VAR_VECTOR, 3, true);
2669+
2670+
$insertedDoc = $database->upsertDocument('vectorUpsert', new Document([
2671+
'$id' => 'vectorUpsert',
2672+
'$permissions' => [
2673+
Permission::read(Role::any()),
2674+
Permission::update(Role::any())
2675+
],
2676+
'embedding' => [1.0, 0.0, 0.0],
2677+
]));
2678+
2679+
$this->assertEquals([1.0, 0.0, 0.0], $insertedDoc->getAttribute('embedding'));
2680+
2681+
$insertedDoc = $database->getDocument('vectorUpsert', 'vectorUpsert');
2682+
$this->assertEquals([1.0, 0.0, 0.0], $insertedDoc->getAttribute('embedding'));
2683+
2684+
$updatedDoc = $database->upsertDocument('vectorUpsert', new Document([
2685+
'$id' => 'vectorUpsert',
2686+
'$permissions' => [
2687+
Permission::read(Role::any()),
2688+
Permission::update(Role::any())
2689+
],
2690+
'embedding' => [2.0, 0.0, 0.0],
2691+
]));
2692+
2693+
$this->assertEquals([2.0, 0.0, 0.0], $updatedDoc->getAttribute('embedding'));
2694+
2695+
$updatedDoc = $database->getDocument('vectorUpsert', 'vectorUpsert');
2696+
$this->assertEquals([2.0, 0.0, 0.0], $updatedDoc->getAttribute('embedding'));
2697+
2698+
$database->deleteCollection('vectorUpsert');
2699+
}
26272700
}

0 commit comments

Comments
 (0)