Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,82 @@ void test_recognizerFeedback_withInvalidRecognizerId_fallsBackToAllRecognizers(T
}
}

@Test
void test_tagUsageCountReflectsActualAssets(TestNamespace ns) {
OpenMetadataClient client = SdkClients.adminClient();
Classification classification = createClassification(ns);

CreateTag request = new CreateTag();
request.setName(ns.shortPrefix("usage_count_tag"));
request.setClassification(classification.getFullyQualifiedName());
request.setDescription("Tag for usage count verification");
Tag tag = createEntity(request);

Tag initialFetch = client.tags().get(tag.getId().toString(), "usageCount");
assertEquals(0, initialFetch.getUsageCount(), "Usage count must be 0 before applying tag");

org.openmetadata.schema.entity.services.DatabaseService dbService =
createDatabaseService(ns, "usage_count_svc");
org.openmetadata.schema.entity.data.Database database =
createDatabase(ns, dbService.getFullyQualifiedName());
org.openmetadata.schema.entity.data.DatabaseSchema schema =
createDatabaseSchema(ns, database.getFullyQualifiedName());

org.openmetadata.schema.type.TagLabel tagLabel =
new org.openmetadata.schema.type.TagLabel()
.withTagFQN(tag.getFullyQualifiedName())
.withSource(org.openmetadata.schema.type.TagLabel.TagSource.CLASSIFICATION)
.withLabelType(org.openmetadata.schema.type.TagLabel.LabelType.MANUAL);

org.openmetadata.schema.api.data.CreateTable createTable =
new org.openmetadata.schema.api.data.CreateTable();
createTable.setName(ns.shortPrefix("table_one"));
createTable.setDatabaseSchema(schema.getFullyQualifiedName());
createTable.setColumns(
List.of(
new org.openmetadata.schema.type.Column()
.withName("id")
.withDataType(org.openmetadata.schema.type.ColumnDataType.BIGINT)));
createTable.setTags(List.of(tagLabel));
SdkClients.adminClient().tables().create(createTable);

Tag afterOneAsset = client.tags().get(tag.getId().toString(), "usageCount");
assertEquals(1, afterOneAsset.getUsageCount(), "Usage count must be 1 after tagging one table");

org.openmetadata.schema.api.data.CreateTable createTable2 =
new org.openmetadata.schema.api.data.CreateTable();
createTable2.setName(ns.shortPrefix("table_two"));
createTable2.setDatabaseSchema(schema.getFullyQualifiedName());
createTable2.setColumns(
List.of(
new org.openmetadata.schema.type.Column()
.withName("id")
.withDataType(org.openmetadata.schema.type.ColumnDataType.BIGINT)));
createTable2.setTags(List.of(tagLabel));
SdkClients.adminClient().tables().create(createTable2);

Tag afterTwoAssets = client.tags().get(tag.getId().toString(), "usageCount");
assertEquals(
2, afterTwoAssets.getUsageCount(), "Usage count must be 2 after tagging two tables");

Tag fetchedViaList =
client
.tags()
.list(
new ListParams()
.setFields("usageCount")
.setParent(classification.getFullyQualifiedName()))
.getData()
.stream()
.filter(t -> t.getId().equals(tag.getId()))
.findFirst()
.orElseThrow();
assertEquals(
2,
fetchedViaList.getUsageCount(),
"Usage count in list response (batchFetchUsageCounts path) must match");
}

private org.openmetadata.schema.entity.services.DatabaseService createDatabaseService(
TestNamespace ns, String serviceName) {
org.openmetadata.schema.api.services.CreateDatabaseService createService =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -710,38 +710,50 @@ private Map<UUID, EntityReference> batchFetchParents(List<Tag> tags) {
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}

record UsageCountQuery(String template, Map<String, Object> bindings) {}

// Package-private for testing
UsageCountQuery buildUsageCountQuery(List<String> tagFQNs) {
var sb = new StringBuilder();
Map<String, Object> bindings = new HashMap<>();
bindings.put("source", TagSource.CLASSIFICATION.ordinal());

for (int i = 0; i < tagFQNs.size(); i++) {
if (i > 0) {
sb.append(" UNION ALL ");
}
sb.append(
"""
SELECT :tagFQN_%d as tagFQN,
COUNT(DISTINCT targetFQNHash) as count
FROM tag_usage
WHERE source = :source
AND (tagFQNHash = :hash_%d OR tagFQNHash LIKE CONCAT(:hash_%d, '.%%'))
"""
.formatted(i, i, i));
bindings.put("tagFQN_" + i, tagFQNs.get(i));
bindings.put("hash_" + i, FullyQualifiedName.buildHash(tagFQNs.get(i)));
}
return new UsageCountQuery(sb.toString(), Collections.unmodifiableMap(bindings));
}

private Map<String, Integer> batchFetchUsageCounts(List<Tag> tags) {
if (tags == null || tags.isEmpty()) {
return Map.of();
}

// Build and execute a single query for all tags
var tagFQNs = tags.stream().map(Tag::getFullyQualifiedName).toList();

// Build UNION query that gets counts for all tags in one go
var queryBuilder = new StringBuilder();
tagFQNs.forEach(
tagFQN -> {
if (!queryBuilder.isEmpty()) {
queryBuilder.append(" UNION ALL ");
}
var escapedFQN = tagFQN.replace("'", "''");
queryBuilder.append(
"""
SELECT '%s' as tagFQN,
COUNT(DISTINCT targetFQNHash) as count
FROM tag_usage
WHERE source = %d
AND (tagFQNHash = MD5('%s') OR tagFQNHash LIKE CONCAT(MD5('%s'), '.%%'))
"""
.formatted(
escapedFQN, TagSource.CLASSIFICATION.ordinal(), escapedFQN, escapedFQN));
});

try {
var usageCountQuery = buildUsageCountQuery(tagFQNs);
var results =
Entity.getJdbi()
.withHandle(handle -> handle.createQuery(queryBuilder.toString()).mapToMap().list());
.withHandle(
handle -> {
var query = handle.createQuery(usageCountQuery.template());
usageCountQuery.bindings().forEach((k, v) -> query.bind(k, v.toString()));
return query.mapToMap().list();
});

return results.stream()
.filter(row -> row.get("tagFQN") != null)
Expand All @@ -754,7 +766,6 @@ private Map<String, Integer> batchFetchUsageCounts(List<Tag> tags) {
}));
} catch (Exception e) {
LOG.error("Error batch fetching usage counts", e);
// Fall back to individual queries
return daoCollection
.tagUsageDAO()
.getTagCountsBulk(TagSource.CLASSIFICATION.ordinal(), tagFQNs);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package org.openmetadata.service.jdbi3;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.when;

import java.nio.charset.StandardCharsets;
Expand All @@ -19,6 +21,7 @@
import org.openmetadata.schema.type.Recognizer;
import org.openmetadata.schema.utils.ResultList;
import org.openmetadata.service.exception.BadCursorException;
import org.openmetadata.service.util.FullyQualifiedName;

public class TagRepositoryUnitTest {
private static final TagRepository tagRepository;
Expand Down Expand Up @@ -269,6 +272,91 @@ void test_paginationBoundaries_lastPage() {
assertNull(thirdPage.getPaging().getAfter());
}

// ===================================================================
// USAGE COUNT QUERY TESTS — verifies batchFetchUsageCounts uses correct hash
//
// tag_usage.tagFQNHash stores hashes produced by FullyQualifiedName.buildHash:
// each FQN segment is hashed individually and joined with ".".
// e.g. "PII.Sensitive" → hash("PII") + "." + hash("Sensitive")
//
// The original bug used MySQL's MD5(fullFqnString) directly in the SQL, which
// computes MD5("PII.Sensitive") — a flat 32-char hex string that never matches
// any row, returning usageCount = 0 for every tag.
// ===================================================================

private static final TagRepository realTagRepository;

static {
realTagRepository = Mockito.mock(TagRepository.class);
when(realTagRepository.buildUsageCountQuery(Mockito.anyList())).thenCallRealMethod();
}

@Test
void test_usageCountQuery_bindingsContainCorrectHashNotRawFqn() {
String tagFqn = "PII.Sensitive";
TagRepository.UsageCountQuery result = realTagRepository.buildUsageCountQuery(List.of(tagFqn));
String expectedHash = FullyQualifiedName.buildHash(tagFqn);

assertEquals(
expectedHash, result.bindings().get("hash_0"), "Binding hash_0 must use buildHash result");
assertTrue(expectedHash.contains("."), "buildHash of a multi-segment FQN must contain '.'");
}

@Test
void test_usageCountQuery_templateUsesNamedParams() {
String tagFqn = "PII.Sensitive";
TagRepository.UsageCountQuery result = realTagRepository.buildUsageCountQuery(List.of(tagFqn));
String template = result.template();
String expectedHash = FullyQualifiedName.buildHash(tagFqn);

assertTrue(template.contains(":hash_0"), "Template must use named param :hash_0, not raw hash");
assertTrue(
template.contains(":tagFQN_0"), "Template must use named param :tagFQN_0, not raw FQN");
assertFalse(
template.contains("tagFQNHash = '" + tagFqn + "'"),
"Template must not embed raw FQN as hash value");
assertFalse(
template.contains("tagFQNHash = '" + expectedHash + "'"),
"Template must use named params, not inline hash literals");
}

@Test
void test_usageCountQuery_multipleTagsUnionAll() {
List<String> tagFqns = List.of("PII.Sensitive", "PII.Personal", "Tier.Tier1");
TagRepository.UsageCountQuery result = realTagRepository.buildUsageCountQuery(tagFqns);
String template = result.template();

assertEquals(
2, countOccurrences(template, "UNION ALL"), "3 tags must produce 2 UNION ALL joins");
for (int i = 0; i < tagFqns.size(); i++) {
String expectedHash = FullyQualifiedName.buildHash(tagFqns.get(i));
assertEquals(
expectedHash,
result.bindings().get("hash_" + i),
"Binding hash_" + i + " must use buildHash result for: " + tagFqns.get(i));
assertEquals(
tagFqns.get(i),
result.bindings().get("tagFQN_" + i),
"Binding tagFQN_" + i + " must equal the original FQN");
}
}

@Test
void test_usageCountQuery_emptyList_returnsEmptyTemplate() {
TagRepository.UsageCountQuery result = realTagRepository.buildUsageCountQuery(List.of());
assertTrue(result.template().isEmpty(), "Empty tag list must produce empty query template");
}

private static int countOccurrences(String text, String pattern) {
int count = 0;
int idx = 0;
while ((idx = text.indexOf(pattern, idx)) != -1) {
count++;
idx += pattern.length();
}
return count;
}

@Test
void test_backwardPaginationFromMiddle_returnsCorrectOrder() {
Tag tag = createTagWithRecognizers(30);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -648,4 +648,54 @@ describe('ClassificationDetails', () => {
expect(screen.getByTestId('header')).toBeInTheDocument();
expect(screen.getByTestId('tags-table')).toBeInTheDocument();
});

it('should fetch tags with usageCount field included', async () => {
render(
<MemoryRouter>
<ClassificationDetails {...defaultProps} />
</MemoryRouter>
);

await waitFor(() => expect(mockGetTags).toHaveBeenCalled());

expect(mockGetTags).toHaveBeenCalledWith(
expect.objectContaining({
fields: expect.stringContaining('usageCount'),
})
);
});

it('should show usageCount in the table when tags have asset counts', async () => {
const tagsWithUsage: Tag[] = [
{
id: 'tag-3',
name: 'Tag3',
displayName: 'Tag Three',
description: 'Third tag',
fullyQualifiedName: 'TestClassification.Tag3',
provider: ProviderType.User,
usageCount: 5,
},
];
mockGetTags.mockResolvedValueOnce({
data: tagsWithUsage,
paging: { total: 1 },
});

render(
<MemoryRouter>
<ClassificationDetails {...defaultProps} />
</MemoryRouter>
);

await waitFor(() =>
expect(screen.getByTestId('tag-row-Tag3')).toBeInTheDocument()
);

expect(mockGetTags).toHaveBeenCalledWith(
expect.objectContaining({
fields: expect.stringContaining('usageCount'),
})
);
});
});
Loading
Loading