Skip to content

Commit 2854b04

Browse files
SecAI-Hubclaude
andcommitted
Add weight distribution statistical fingerprinting for trojan/corruption detection (M28)
Phase 8: Per-tensor statistical analysis (mean, variance, kurtosis, zero-fraction) for GGUF and safetensors files. Flags abnormal weight distributions that may indicate trojan patches, corrupted weights, or steganographic payloads. Supports F32, F16, and Q8_0 quantized tensors. Integrated into check_static_scan pipeline. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1408e7f commit 2854b04

2 files changed

Lines changed: 560 additions & 1 deletion

File tree

services/quarantine/quarantine/pipeline.py

Lines changed: 364 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,8 +602,367 @@ def _run_modelaudit(filepath: Path) -> dict:
602602
return {"passed": True, "scanner": "modelaudit", "note": f"modelaudit error: {e}"}
603603

604604

605+
# ---------------------------------------------------------------------------
606+
# Weight distribution statistical fingerprinting
607+
# ---------------------------------------------------------------------------
608+
609+
# GGUF type IDs → (numpy dtype string, byte size)
610+
_GGUF_TYPE_INFO = {
611+
0: ("f32", 4), # GGUF_TYPE_F32
612+
1: ("f16", 2), # GGUF_TYPE_F16
613+
6: ("f32", 4), # GGUF_TYPE_F64 → read as f32 pairs
614+
# Quantized types are not directly interpretable as floats;
615+
# we dequantize Q8_0 blocks and skip others.
616+
8: ("q8_0", 34), # GGUF_TYPE_Q8_0: 34 bytes per block of 32 values
617+
}
618+
619+
# Thresholds for anomaly detection
620+
WEIGHT_STATS_MAX_KURTOSIS = 100.0 # Extremely peaked = suspicious
621+
WEIGHT_STATS_MAX_MEAN_ABS = 10.0 # Unusually large mean
622+
WEIGHT_STATS_MIN_VARIANCE = 1e-12 # All-zero or constant tensor
623+
WEIGHT_STATS_MAX_ZERO_FRACTION = 0.99 # Nearly all zeros = possibly corrupted
624+
625+
626+
def _analyze_weight_distribution(artifact_path: Path) -> dict:
627+
"""Statistical fingerprinting of model weights.
628+
629+
Reads tensor data from GGUF or safetensors files and computes per-layer
630+
statistics (mean, variance, kurtosis, zero-fraction). Flags anomalies
631+
that may indicate:
632+
- Trojan patches (localized extreme values)
633+
- Corrupted/zeroed weights
634+
- Steganographic payloads (unusual distribution shape)
635+
"""
636+
ext = artifact_path.suffix.lower()
637+
try:
638+
if ext == ".gguf":
639+
return _analyze_gguf_weights(artifact_path)
640+
elif ext == ".safetensors":
641+
return _analyze_safetensors_weights(artifact_path)
642+
else:
643+
return {"passed": True, "note": f"weight analysis not supported for {ext}"}
644+
except Exception as e:
645+
log.warning("weight distribution analysis failed: %s", e)
646+
return {"passed": True, "note": f"analysis error (non-fatal): {e}"}
647+
648+
649+
def _compute_tensor_stats(data: bytes, count: int, dtype: str = "f32") -> dict | None:
650+
"""Compute statistics for a raw tensor buffer.
651+
652+
Returns dict with mean, variance, kurtosis, zero_fraction, or None on error.
653+
"""
654+
if count == 0:
655+
return None
656+
657+
# Sample up to 1M values for large tensors (performance)
658+
max_samples = 1_000_000
659+
660+
if dtype == "f32":
661+
fmt_size = 4
662+
fmt_char = "f"
663+
elif dtype == "f16":
664+
# Read f16 as unsigned 16-bit, convert manually
665+
fmt_size = 2
666+
fmt_char = "e" # IEEE 754 half-precision
667+
elif dtype == "q8_0":
668+
# Q8_0 dequantization: each block = 2-byte f16 scale + 32 int8 values
669+
return _dequant_q8_0_stats(data, count)
670+
else:
671+
return None
672+
673+
actual_count = min(count, len(data) // fmt_size)
674+
if actual_count < 16:
675+
return None
676+
677+
sample_count = min(actual_count, max_samples)
678+
# Use struct to unpack values (no numpy dependency)
679+
step = max(1, actual_count // sample_count)
680+
values = []
681+
for i in range(0, actual_count, step):
682+
offset = i * fmt_size
683+
if offset + fmt_size > len(data):
684+
break
685+
try:
686+
val = struct.unpack_from(f"<{fmt_char}", data, offset)[0]
687+
if math.isfinite(val):
688+
values.append(val)
689+
except struct.error:
690+
break
691+
692+
if len(values) < 16:
693+
return None
694+
695+
return _stats_from_values(values)
696+
697+
698+
def _dequant_q8_0_stats(data: bytes, _element_count: int) -> dict | None:
699+
"""Dequantize Q8_0 blocks and compute stats.
700+
701+
Q8_0 format: each block contains 2-byte f16 scale + 32 int8 quantized values.
702+
"""
703+
block_size = 34 # 2 (scale) + 32 (quants)
704+
n_blocks = len(data) // block_size
705+
if n_blocks == 0:
706+
return None
707+
708+
max_blocks = 32768 # Sample ~1M values
709+
step = max(1, n_blocks // max_blocks)
710+
values = []
711+
712+
for bi in range(0, n_blocks, step):
713+
offset = bi * block_size
714+
if offset + block_size > len(data):
715+
break
716+
try:
717+
scale = struct.unpack_from("<e", data, offset)[0]
718+
if not math.isfinite(scale):
719+
continue
720+
for qi in range(32):
721+
qval = struct.unpack_from("b", data, offset + 2 + qi)[0]
722+
values.append(scale * qval)
723+
except struct.error:
724+
break
725+
726+
if len(values) < 16:
727+
return None
728+
729+
return _stats_from_values(values)
730+
731+
732+
def _stats_from_values(values: list) -> dict:
733+
"""Compute mean, variance, kurtosis, zero fraction from a list of floats."""
734+
n = len(values)
735+
mean = sum(values) / n
736+
var = sum((v - mean) ** 2 for v in values) / n
737+
738+
# Excess kurtosis (normal distribution = 0)
739+
if var > 0:
740+
m4 = sum((v - mean) ** 4 for v in values) / n
741+
kurtosis = m4 / (var ** 2) - 3.0
742+
else:
743+
kurtosis = 0.0
744+
745+
zero_count = sum(1 for v in values if v == 0.0)
746+
747+
return {
748+
"mean": round(mean, 6),
749+
"variance": round(var, 6),
750+
"kurtosis": round(kurtosis, 4),
751+
"zero_fraction": round(zero_count / n, 4),
752+
"samples": n,
753+
}
754+
755+
756+
def _analyze_gguf_weights(filepath: Path) -> dict:
757+
"""Parse GGUF tensor info and compute weight statistics."""
758+
anomalies = []
759+
tensor_stats = []
760+
761+
try:
762+
with open(filepath, "rb") as f:
763+
magic = f.read(4)
764+
if magic != GGUF_MAGIC:
765+
return {"passed": True, "note": "not a valid GGUF file"}
766+
767+
struct.unpack("<I", f.read(4)) # version, already validated
768+
n_tensors = struct.unpack("<Q", f.read(8))[0]
769+
n_kv = struct.unpack("<Q", f.read(8))[0]
770+
771+
# Skip metadata KV pairs
772+
for _ in range(n_kv):
773+
key_len = struct.unpack("<Q", f.read(8))[0]
774+
f.seek(key_len, 1) # skip key
775+
val_type = struct.unpack("<I", f.read(4))[0]
776+
_skip_gguf_value(f, val_type)
777+
778+
# Read tensor info entries
779+
tensor_infos = []
780+
for _ in range(min(n_tensors, 2000)): # cap to prevent abuse
781+
name_len = struct.unpack("<Q", f.read(8))[0]
782+
name = f.read(name_len).decode("utf-8", errors="replace")
783+
n_dims = struct.unpack("<I", f.read(4))[0]
784+
dims = [struct.unpack("<Q", f.read(8))[0] for _ in range(n_dims)]
785+
dtype_id = struct.unpack("<I", f.read(4))[0]
786+
offset = struct.unpack("<Q", f.read(8))[0]
787+
element_count = 1
788+
for d in dims:
789+
element_count *= d
790+
tensor_infos.append({
791+
"name": name,
792+
"dims": dims,
793+
"dtype_id": dtype_id,
794+
"offset": offset,
795+
"element_count": element_count,
796+
})
797+
798+
# Data starts at alignment boundary after header
799+
header_end = f.tell()
800+
alignment = 32 # GGUF default alignment
801+
data_start = ((header_end + alignment - 1) // alignment) * alignment
802+
803+
# Analyze a sample of tensors (largest ones are most informative)
804+
# Sort by element count descending, take top 20
805+
tensor_infos.sort(key=lambda t: t["element_count"], reverse=True)
806+
sample_tensors = tensor_infos[:20]
807+
808+
for tinfo in sample_tensors:
809+
dtype_id = tinfo["dtype_id"]
810+
if dtype_id not in _GGUF_TYPE_INFO:
811+
continue # Skip unsupported quantization types
812+
813+
dtype_name, type_size = _GGUF_TYPE_INFO[dtype_id]
814+
if dtype_name == "q8_0":
815+
n_blocks = (tinfo["element_count"] + 31) // 32
816+
data_size = n_blocks * 34
817+
else:
818+
data_size = tinfo["element_count"] * type_size
819+
820+
# Cap read size to 32MB per tensor
821+
read_size = min(data_size, 32 * 1024 * 1024)
822+
823+
f.seek(data_start + tinfo["offset"])
824+
raw = f.read(read_size)
825+
826+
stats = _compute_tensor_stats(raw, tinfo["element_count"], dtype_name)
827+
if stats is None:
828+
continue
829+
830+
stats["name"] = tinfo["name"]
831+
tensor_stats.append(stats)
832+
833+
# Check anomaly thresholds
834+
issues = _check_weight_anomalies(tinfo["name"], stats)
835+
anomalies.extend(issues)
836+
837+
except (struct.error, OSError) as e:
838+
return {"passed": True, "note": f"GGUF weight parse error (non-fatal): {e}"}
839+
840+
if anomalies:
841+
return {
842+
"passed": False,
843+
"reason": f"weight distribution anomalies: {'; '.join(anomalies[:3])}",
844+
"anomalies": anomalies,
845+
"tensors_analyzed": len(tensor_stats),
846+
}
847+
848+
return {
849+
"passed": True,
850+
"tensors_analyzed": len(tensor_stats),
851+
"tensor_stats": tensor_stats[:5], # Include top 5 for provenance
852+
}
853+
854+
855+
def _analyze_safetensors_weights(filepath: Path) -> dict:
856+
"""Parse safetensors header and compute weight statistics on tensors."""
857+
anomalies = []
858+
tensor_stats = []
859+
860+
try:
861+
with open(filepath, "rb") as f:
862+
header_len = struct.unpack("<Q", f.read(8))[0]
863+
if header_len > SAFETENSORS_MAX_HEADER:
864+
return {"passed": True, "note": "header too large for weight analysis"}
865+
header_raw = f.read(header_len)
866+
header = json.loads(header_raw)
867+
data_start = 8 + header_len
868+
869+
# Collect tensor metadata
870+
tensors = []
871+
for name, info in header.items():
872+
if name == "__metadata__":
873+
continue
874+
dtype = info.get("dtype", "")
875+
offsets = info.get("data_offsets", [0, 0])
876+
start, end = offsets[0], offsets[1]
877+
size_bytes = end - start
878+
tensors.append({
879+
"name": name,
880+
"dtype": dtype,
881+
"offset": start,
882+
"size_bytes": size_bytes,
883+
})
884+
885+
# Sort by size, analyze top 20
886+
tensors.sort(key=lambda t: t["size_bytes"], reverse=True)
887+
sample_tensors = tensors[:20]
888+
889+
for tinfo in sample_tensors:
890+
dtype = tinfo["dtype"]
891+
if dtype == "F32":
892+
fmt_dtype = "f32"
893+
elem_size = 4
894+
elif dtype == "F16":
895+
fmt_dtype = "f16"
896+
elem_size = 2
897+
elif dtype == "BF16":
898+
# BF16 not directly supported by struct; skip
899+
continue
900+
else:
901+
continue
902+
903+
element_count = tinfo["size_bytes"] // elem_size
904+
read_size = min(tinfo["size_bytes"], 32 * 1024 * 1024)
905+
906+
f.seek(data_start + tinfo["offset"])
907+
raw = f.read(read_size)
908+
909+
stats = _compute_tensor_stats(raw, element_count, fmt_dtype)
910+
if stats is None:
911+
continue
912+
913+
stats["name"] = tinfo["name"]
914+
tensor_stats.append(stats)
915+
916+
issues = _check_weight_anomalies(tinfo["name"], stats)
917+
anomalies.extend(issues)
918+
919+
except (struct.error, OSError, json.JSONDecodeError) as e:
920+
return {"passed": True, "note": f"safetensors weight parse error (non-fatal): {e}"}
921+
922+
if anomalies:
923+
return {
924+
"passed": False,
925+
"reason": f"weight distribution anomalies: {'; '.join(anomalies[:3])}",
926+
"anomalies": anomalies,
927+
"tensors_analyzed": len(tensor_stats),
928+
}
929+
930+
return {
931+
"passed": True,
932+
"tensors_analyzed": len(tensor_stats),
933+
"tensor_stats": tensor_stats[:5],
934+
}
935+
936+
937+
def _check_weight_anomalies(tensor_name: str, stats: dict) -> list:
938+
"""Check a tensor's statistics against anomaly thresholds."""
939+
issues = []
940+
941+
if abs(stats["mean"]) > WEIGHT_STATS_MAX_MEAN_ABS:
942+
issues.append(
943+
f"{tensor_name}: abnormal mean ({stats['mean']:.4f})"
944+
)
945+
946+
if stats["variance"] < WEIGHT_STATS_MIN_VARIANCE and stats["zero_fraction"] < 0.99:
947+
issues.append(
948+
f"{tensor_name}: near-zero variance ({stats['variance']:.2e}) with non-zero values"
949+
)
950+
951+
if stats["kurtosis"] > WEIGHT_STATS_MAX_KURTOSIS:
952+
issues.append(
953+
f"{tensor_name}: extreme kurtosis ({stats['kurtosis']:.2f}), possible trojan patch"
954+
)
955+
956+
if stats["zero_fraction"] > WEIGHT_STATS_MAX_ZERO_FRACTION:
957+
issues.append(
958+
f"{tensor_name}: {stats['zero_fraction']*100:.1f}% zeros, possibly corrupted"
959+
)
960+
961+
return issues
962+
963+
605964
def check_static_scan(artifact_path: Path, policy: dict | None = None) -> dict:
606-
"""Stage 5: Run modelscan + fickling + modelaudit + entropy analysis."""
965+
"""Stage 5: Run modelscan + fickling + modelaudit + entropy + weight analysis."""
607966
if policy is None:
608967
policy = {}
609968
results = {}
@@ -628,6 +987,10 @@ def check_static_scan(artifact_path: Path, policy: dict | None = None) -> dict:
628987
entropy_result = _check_file_entropy(artifact_path)
629988
results["entropy"] = entropy_result
630989

990+
# 6. Weight distribution analysis (new, no external dep)
991+
weight_result = _analyze_weight_distribution(artifact_path)
992+
results["weight_stats"] = weight_result
993+
631994
# Overall: fail if ANY scanner fails
632995
failed = [k for k, v in results.items() if not v.get("passed", True)]
633996
if failed:

0 commit comments

Comments
 (0)