-
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathhsdpy_example.py
More file actions
114 lines (92 loc) · 4.32 KB
/
hsdpy_example.py
File metadata and controls
114 lines (92 loc) · 4.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python
import hsdpy
import numpy as np
import sys
# Element count matching the demo vectors built in main() (each has 5 elements).
# NOTE(review): not referenced anywhere else in this script as shown -- confirm
# whether it was meant to drive vector construction/validation or can be removed.
VECTOR_DIM = 5
def print_system_info():
    """Print HsdPy library metadata and the active backend.

    Queries ``hsdpy.get_library_info()`` and prints its ``lib_path``,
    ``system``, ``arch`` and ``backend`` entries, using ``'N/A'`` for any
    missing key. If that query (or reading its result) raises, the error is
    written to stderr and only ``hsdpy.get_backend()`` is tried instead; a
    failure there is reported to stderr as well. A ``---`` separator line is
    always printed at the end.
    """
    print("--- Library Info ---")
    # (label shown to the user, key looked up in the info mapping)
    info_fields = (
        ("Library Path", "lib_path"),
        ("System", "system"),
        ("Architecture", "arch"),
        ("Active Backend", "backend"),
    )
    try:
        lib_info = hsdpy.get_library_info()
        for label, key in info_fields:
            print(f"{label}: {lib_info.get(key, 'N/A')}")
    except Exception as info_err:
        print(f"Could not retrieve full library info: {info_err}", file=sys.stderr)
        # Fall back to the narrower backend query so at least that is shown.
        try:
            print(f"Active Backend (fallback): {hsdpy.get_backend()}")
        except Exception as backend_err:
            print(f"Could not retrieve backend info: {backend_err}", file=sys.stderr)
    print("---")
def run_and_print(description, func, *args, is_integer_result=False):
    """
    Call an HsdPy function and print its result, or report why it failed.

    Args:
        description (str): Label printed in front of the result or error.
        func (callable): The HsdPy function to invoke.
        *args: Positional arguments forwarded unchanged to ``func``
            (e.g. numpy arrays).
        is_integer_result (bool): When True the result is printed with its
            default formatting; otherwise with four decimal places.

    Nothing is re-raised. ``hsdpy.HsdError`` is reported to stderr with its
    ``status_code`` and ``message``. ``NotImplementedError`` is reported as a
    function missing from the C library. Any other exception is reported with
    its type name. Because the print happens inside the ``try``, a result that
    cannot be formatted also lands in the last handler.
    """
    try:
        value = func(*args)
        # An empty spec is what a bare ``{value}`` f-string field uses.
        spec = "" if is_integer_result else ".4f"
        print(f"{description}: {format(value, spec)}")
    except hsdpy.HsdError as err:
        print(f"ERROR: {description} failed - HsdError Status={err.status_code}, Msg='{err.message}'",
              file=sys.stderr)
    except NotImplementedError as err:
        print(f"ERROR: {description} failed - Function not available in C library. {err}",
              file=sys.stderr)
    except Exception as err:
        # Everything else, e.g. TypeError/ValueError from argument validation.
        print(f"ERROR: {description} failed - {type(err).__name__}: {err}", file=sys.stderr)
def main():
    """Demonstrate the HsdPy distance/similarity API on small fixed vectors.

    Prints library/system information, runs each metric through
    ``run_and_print`` (which reports per-call failures to stderr instead of
    raising), and finally reports the backend returned by
    ``hsdpy.get_backend()``. A failure of that last query is reported to
    stderr rather than aborting the script, matching ``print_system_info``.
    """
    # --- Vector Definitions ---
    # Float vectors (for Euclidean, Manhattan, Cosine, Dot)
    vec_a_f32 = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
    vec_b_f32 = np.array([5.0, 4.0, 3.0, 2.0, 1.0], dtype=np.float32)
    # Binary vectors (uint8 for Hamming)
    vec_a_bin_u8 = np.array([1, 1, 0, 1, 0], dtype=np.uint8)
    vec_b_bin_u8 = np.array([1, 0, 1, 1, 1], dtype=np.uint8)
    # Binary-style vectors (uint16 for Jaccard - binary case)
    vec_a_bin_u16 = np.array([1, 1, 0, 1, 0], dtype=np.uint16)
    vec_b_bin_u16 = np.array([1, 0, 1, 1, 1], dtype=np.uint16)
    # Weighted/non-binary vectors (uint16 for Jaccard/Tanimoto - non-binary case)
    vec_a_weighted_u16 = np.array([3, 5, 0, 2, 0], dtype=np.uint16)
    vec_b_weighted_u16 = np.array([1, 5, 4, 2, 3], dtype=np.uint16)

    # --- Show System Info ---
    print_system_info()

    # --- Run Calculations ---
    print("\n--- Calculations (Using Auto-Detected Backend) ---")
    # Distance Metrics
    run_and_print("Squared Euclidean Distance (f32)",
                  hsdpy.dist_sqeuclidean_f32, vec_a_f32, vec_b_f32)
    run_and_print("Manhattan Distance (f32)",
                  hsdpy.dist_manhattan_f32, vec_a_f32, vec_b_f32)
    run_and_print("Hamming Distance (u8 binary)",
                  hsdpy.dist_hamming_u8, vec_a_bin_u8, vec_b_bin_u8,
                  is_integer_result=True)
    # Similarity Measures
    run_and_print("Dot Product Similarity (f32)",
                  hsdpy.sim_dot_f32, vec_a_f32, vec_b_f32)
    run_and_print("Cosine Similarity (f32)",
                  hsdpy.sim_cosine_f32, vec_a_f32, vec_b_f32)
    # Jaccard/Tanimoto Examples -- the same u16 Jaccard entry point is used for
    # both the binary and the weighted (Tanimoto-style) inputs.
    run_and_print("Jaccard Similarity (u16 binary input)",
                  hsdpy.sim_jaccard_u16, vec_a_bin_u16, vec_b_bin_u16)
    run_and_print("Tanimoto Coefficient (u16 non-binary input)",
                  hsdpy.sim_jaccard_u16, vec_a_weighted_u16, vec_b_weighted_u16)

    # --- Backend Selection ---
    print("\n--- Backend Selection ---")
    # print_system_info() treats get_backend() as fallible; do the same here so
    # a backend lookup failure is reported instead of ending in a traceback.
    try:
        backend = hsdpy.get_backend()
    except Exception as err:
        print(f"Could not retrieve backend info: {err}", file=sys.stderr)
    else:
        print(f"The active backend detected was: {backend}")


if __name__ == "__main__":
    main()