@@ -847,3 +847,385 @@ mod tests {
847847 }
848848 }
849849}
850+
851+ // ============================================================================
852+ // DeepNSM Evaluation Pipeline — transcoded from Python DeepNSM
853+ // ============================================================================
854+
855+ /// The full NSM primes set including multi-word primes (from Python utils.py).
856+ static NSM_PRIMES_SET : LazyLock < HashSet < & ' static str > > = LazyLock :: new ( || {
857+ [
858+ "i" , "you" , "someone" , "people" , "something" , "thing" , "body" , "kind" , "part" ,
859+ "this" , "the same" , "other" , "else" , "another" , "one" , "two" , "some" , "all" ,
860+ "much" , "many" , "little" , "few" , "good" , "bad" , "big" , "small" , "think" , "know" ,
861+ "want" , "don't want" , "feel" , "see" , "hear" , "say" , "words" , "true" , "do" ,
862+ "happen" , "move" , "there" , "is" , "be" , "mine" , "live" , "die" , "when" , "time" ,
863+ "now" , "before" , "after" , "a long time" , "a short time" , "for some time" ,
864+ "moment" , "where" , "place" , "here" , "above" , "below" , "far" , "near" , "side" ,
865+ "inside" , "touch" , "not" , "maybe" , "can" , "because" , "if" , "very" , "more" ,
866+ "like" , "as" , "way" , "said" ,
867+ ] . into_iter ( ) . collect ( )
868+ } ) ;
869+
870+ /// Check if a word is an NSM semantic prime.
871+ pub fn is_nsm_prime ( word : & str ) -> bool {
872+ NSM_PRIMES_SET . contains ( word. to_lowercase ( ) . as_str ( ) )
873+ }
874+
875+ /// English stopwords excluding NSM primes. `LazyLock` one-time init.
876+ static STOP_WORDS : LazyLock < HashSet < & ' static str > > = LazyLock :: new ( || {
877+ let sw: HashSet < & str > = [
878+ "a" , "an" , "and" , "are" , "at" , "been" , "but" , "by" , "did" , "does" ,
879+ "doing" , "down" , "during" , "each" , "for" , "from" , "further" , "had" ,
880+ "has" , "having" , "he" , "her" , "herself" , "him" , "himself" , "his" ,
881+ "how" , "in" , "into" , "it" , "its" , "itself" , "just" , "me" , "my" ,
882+ "myself" , "no" , "nor" , "of" , "off" , "on" , "once" , "only" , "or" ,
883+ "our" , "ours" , "ourselves" , "out" , "over" , "own" , "re" , "s" , "she" ,
884+ "should" , "so" , "such" , "t" , "than" , "that" , "the" , "their" ,
885+ "theirs" , "them" , "themselves" , "then" , "these" , "they" , "those" ,
886+ "through" , "to" , "too" , "under" , "until" , "up" , "ve" , "was" , "we" ,
887+ "were" , "what" , "which" , "while" , "who" , "whom" , "why" , "will" ,
888+ "with" , "won" , "would" , "your" , "yours" , "yourself" , "yourselves" ,
889+ ] . into_iter ( ) . collect ( ) ;
890+ sw. into_iter ( ) . filter ( |w| !NSM_PRIMES_SET . contains ( * w) ) . collect ( )
891+ } ) ;
892+
893+ /// Check if a word is a stopword (but not an NSM prime).
894+ pub fn is_stop_word ( word : & str ) -> bool {
895+ STOP_WORDS . contains ( word. to_lowercase ( ) . as_str ( ) )
896+ }
897+
898+ /// Legal punctuation in NSM explications.
899+ pub const LEGAL_PUNCTUATION : & [ char ] = & [ '\'' , '.' , ',' , ':' , '!' , '?' , '"' , '\n' , '\t' , '(' , ')' , '/' ] ;
900+
901+ // ── Evaluation types ────────────────────────────────────────────────────────
902+
903+ /// A single prediction from a grader model.
904+ #[ derive( Clone , Debug ) ]
905+ pub struct Prediction {
906+ pub prediction : String ,
907+ pub answer_logprob : f32 ,
908+ pub answer_ranks : Vec < usize > ,
909+ pub is_match : bool ,
910+ pub lines_removed : usize ,
911+ }
912+
913+ impl Prediction {
914+ pub fn new ( prediction : & str ) -> Self {
915+ Self { prediction : prediction. to_string ( ) , answer_logprob : 0.0 , answer_ranks : Vec :: new ( ) , is_match : false , lines_removed : 0 }
916+ }
917+ }
918+
919+ /// Substitutability score from one grader model.
920+ #[ derive( Clone , Debug ) ]
921+ pub struct SubstitutabilityScore {
922+ pub model : String ,
923+ pub baselines : Vec < Prediction > ,
924+ pub exp_baselines : Vec < Prediction > ,
925+ pub minimality : Vec < Vec < Prediction > > ,
926+ pub entailments : Vec < Vec < Prediction > > ,
927+ pub adj_score : f32 ,
928+ pub avg_delta_log : f32 ,
929+ pub avg_min_delta_log : f32 ,
930+ pub avg_ent_delta_log : f32 ,
931+ pub total_match : usize ,
932+ }
933+
934+ impl SubstitutabilityScore {
935+ pub fn new ( model : & str ) -> Self {
936+ Self { model : model. to_string ( ) , baselines : Vec :: new ( ) , exp_baselines : Vec :: new ( ) , minimality : Vec :: new ( ) , entailments : Vec :: new ( ) , adj_score : 0.0 , avg_delta_log : 0.0 , avg_min_delta_log : 0.0 , avg_ent_delta_log : 0.0 , total_match : 0 }
937+ }
938+ }
939+
940+ /// An NSM explication with legality scoring.
941+ #[ derive( Clone , Debug ) ]
942+ pub struct Explication {
943+ pub text : String ,
944+ pub target_word : String ,
945+ pub length : usize ,
946+ pub primes : usize ,
947+ pub stop_words_count : usize ,
948+ pub molecules : usize ,
949+ pub unique_molecules : usize ,
950+ pub uses_original_word : bool ,
951+ pub primes_ratio : f32 ,
952+ pub molecules_ratio : f32 ,
953+ pub sub_scores : Vec < SubstitutabilityScore > ,
954+ pub avg_delta : f32 ,
955+ pub avg_delta_min : f32 ,
956+ pub avg_delta_ent : f32 ,
957+ pub score_exp : f32 ,
958+ pub total_score : f32 ,
959+ }
960+
961+ impl Explication {
962+ pub fn new ( text : & str ) -> Self {
963+ Self {
964+ text : text. to_string ( ) , target_word : String :: new ( ) ,
965+ length : 0 , primes : 0 , stop_words_count : 0 , molecules : 0 ,
966+ unique_molecules : 0 , uses_original_word : false , primes_ratio : 0.0 ,
967+ molecules_ratio : 0.0 , sub_scores : Vec :: new ( ) , avg_delta : 0.0 ,
968+ avg_delta_min : 0.0 , avg_delta_ent : 0.0 , score_exp : 0.0 , total_score : 0.0 ,
969+ }
970+ }
971+
972+ /// Score legality against a target word (circularity via stem matching).
973+ pub fn legality_score ( & mut self , word : & str ) {
974+ let clean: String = self . text . to_lowercase ( ) . chars ( )
975+ . filter ( |c| c. is_alphanumeric ( ) || c. is_whitespace ( ) ) . collect ( ) ;
976+ let tokens: Vec < & str > = clean. split_whitespace ( ) . collect ( ) ;
977+ self . target_word = word. to_string ( ) ;
978+ self . length = tokens. len ( ) ;
979+ self . primes = tokens. iter ( ) . filter ( |t| is_nsm_prime ( t) ) . count ( ) ;
980+ self . stop_words_count = tokens. iter ( ) . filter ( |t| is_stop_word ( t) ) . count ( ) ;
981+ let mols: Vec < & & str > = tokens. iter ( ) . filter ( |t| !is_nsm_prime ( t) && !is_stop_word ( t) ) . collect ( ) ;
982+ self . molecules = mols. len ( ) ;
983+ self . unique_molecules = mols. iter ( ) . collect :: < HashSet < _ > > ( ) . len ( ) ;
984+ let wl = word. to_lowercase ( ) ;
985+ let stem = if wl. len ( ) >= 4 { & wl[ ..4 ] } else { & wl } ;
986+ self . uses_original_word = tokens. iter ( ) . any ( |t| * t == wl || ( t. len ( ) >= 4 && t. starts_with ( stem) ) ) ;
987+ self . primes_ratio = if self . length > 0 { self . primes as f32 / self . length as f32 } else { 0.0 } ;
988+ self . molecules_ratio = if self . length > 0 { self . molecules as f32 / self . length as f32 } else { 0.0 } ;
989+ }
990+
991+ /// Compute averages from substitutability sub-scores.
992+ pub fn calculate_averages ( & mut self ) {
993+ if self . sub_scores . is_empty ( ) { return ; }
994+ let n = self . sub_scores . len ( ) as f32 ;
995+ self . avg_delta = self . sub_scores . iter ( ) . map ( |s| s. avg_delta_log ) . sum :: < f32 > ( ) / n;
996+ self . avg_delta_min = self . sub_scores . iter ( ) . map ( |s| s. avg_min_delta_log ) . sum :: < f32 > ( ) / n;
997+ self . avg_delta_ent = self . sub_scores . iter ( ) . map ( |s| s. avg_ent_delta_log ) . sum :: < f32 > ( ) / n;
998+ self . score_exp = self . sub_scores . iter ( ) . map ( |s| s. adj_score ) . sum :: < f32 > ( ) / n;
999+ self . total_score = if !self . uses_original_word {
1000+ 2.0 * ( self . score_exp + 10.0 * self . primes_ratio - 10.0 * self . molecules_ratio )
1001+ } else { 0.0 } ;
1002+ }
1003+
1004+ /// Truncated versions with lines removed from the end.
1005+ pub fn get_truncated ( & self , max_lines_remove : usize ) -> Vec < Explication > {
1006+ let lines: Vec < & str > = self . text . trim ( ) . lines ( ) . collect ( ) ;
1007+ ( 0 ..max_lines_remove. min ( lines. len ( ) ) )
1008+ . map ( |i| Explication :: new ( & lines[ ..lines. len ( ) - ( i + 1 ) ] . join ( "\n " ) ) )
1009+ . collect ( )
1010+ }
1011+ }
1012+
1013+ /// Ambiguous example passage with masked word.
1014+ #[ derive( Clone , Debug ) ]
1015+ pub struct AmbiguousExample {
1016+ pub text : String ,
1017+ pub source : Option < String > ,
1018+ }
1019+
1020+ impl AmbiguousExample {
1021+ pub fn new ( text : & str ) -> Self { Self { text : text. to_string ( ) , source : None } }
1022+
1023+ /// Truncated versions removing non-UNK sentences.
1024+ pub fn get_truncated ( & self , max_remove : usize ) -> Vec < AmbiguousExample > {
1025+ let sents: Vec < & str > = self . text . split ( '.' ) . map ( |s| s. trim ( ) ) . filter ( |s| !s. is_empty ( ) ) . collect ( ) ;
1026+ let non_unk: Vec < usize > = sents. iter ( ) . enumerate ( ) . filter ( |( _, s) | !s. contains ( "<UNK>" ) ) . map ( |( i, _) | i) . collect ( ) ;
1027+ ( 0 ..max_remove. min ( non_unk. len ( ) ) ) . map ( |i| {
1028+ let exclude: HashSet < usize > = non_unk[ ..=i] . iter ( ) . copied ( ) . collect ( ) ;
1029+ let kept: Vec < & str > = sents. iter ( ) . enumerate ( ) . filter ( |( j, _) | !exclude. contains ( j) ) . map ( |( _, s) | * s) . collect ( ) ;
1030+ AmbiguousExample :: new ( & kept. join ( ". " ) )
1031+ } ) . collect ( )
1032+ }
1033+ }
1034+
1035+ /// Aggregated model evaluation result.
1036+ #[ derive( Clone , Debug ) ]
1037+ pub struct ModelResult {
1038+ pub model_name : String ,
1039+ pub num_examples : usize ,
1040+ pub explications : Vec < Explication > ,
1041+ pub avg_primes_ratio : f32 ,
1042+ pub avg_molecules_ratio : f32 ,
1043+ pub avg_total_score : f32 ,
1044+ }
1045+
1046+ impl ModelResult {
1047+ pub fn new ( model_name : & str ) -> Self {
1048+ Self { model_name : model_name. to_string ( ) , num_examples : 0 , explications : Vec :: new ( ) , avg_primes_ratio : 0.0 , avg_molecules_ratio : 0.0 , avg_total_score : 0.0 }
1049+ }
1050+ pub fn calculate_averages ( & mut self ) {
1051+ let n = self . explications . len ( ) as f32 ;
1052+ if n == 0.0 { return ; }
1053+ self . avg_primes_ratio = self . explications . iter ( ) . map ( |e| e. primes_ratio ) . sum :: < f32 > ( ) / n;
1054+ self . avg_molecules_ratio = self . explications . iter ( ) . map ( |e| e. molecules_ratio ) . sum :: < f32 > ( ) / n;
1055+ self . avg_total_score = self . explications . iter ( ) . map ( |e| e. total_score ) . sum :: < f32 > ( ) / n;
1056+ }
1057+ }
1058+
1059+ // ── CAM-PQ bridge ───────────────────────────────────────────────────────────
1060+
1061+ /// Load DeepNSM codebook (`codebook_pq.bin`) into ndarray's CamCodebook.
1062+ pub fn load_nsm_codebook ( codebook_bytes : & [ u8 ] ) -> super :: cam_pq:: CamCodebook {
1063+ use super :: cam_pq:: { CamCodebook , SubspaceCodebook , NUM_CENTROIDS , NUM_SUBSPACES } ;
1064+ let expected = NUM_SUBSPACES * NUM_CENTROIDS * 16 * 4 ;
1065+ assert_eq ! ( codebook_bytes. len( ) , expected, "codebook_pq.bin: expected {expected} bytes, got {}" , codebook_bytes. len( ) ) ;
1066+ let mut codebooks: Vec < SubspaceCodebook > = Vec :: with_capacity ( NUM_SUBSPACES ) ;
1067+ for s in 0 ..NUM_SUBSPACES {
1068+ let mut centroids = Vec :: with_capacity ( NUM_CENTROIDS ) ;
1069+ for c in 0 ..NUM_CENTROIDS {
1070+ let mut centroid = Vec :: with_capacity ( 16 ) ;
1071+ for d in 0 ..16 {
1072+ let off = ( s * NUM_CENTROIDS * 16 + c * 16 + d) * 4 ;
1073+ centroid. push ( f32:: from_le_bytes ( [ codebook_bytes[ off] , codebook_bytes[ off+1 ] , codebook_bytes[ off+2 ] , codebook_bytes[ off+3 ] ] ) ) ;
1074+ }
1075+ centroids. push ( centroid) ;
1076+ }
1077+ codebooks. push ( SubspaceCodebook { centroids, subspace_dim : 16 } ) ;
1078+ }
1079+ CamCodebook { codebooks : codebooks. try_into ( ) . unwrap ( ) , total_dim : 96 , subspace_dim : 16 }
1080+ }
1081+
1082+ /// Load CAM codes (`cam_codes.bin`): N words × 6 bytes.
1083+ pub fn load_cam_codes ( bytes : & [ u8 ] ) -> Vec < super :: cam_pq:: CamFingerprint > {
1084+ assert_eq ! ( bytes. len( ) % 6 , 0 ) ;
1085+ bytes. chunks_exact ( 6 ) . map ( |c| { let mut fp = [ 0u8 ; 6 ] ; fp. copy_from_slice ( c) ; fp } ) . collect ( )
1086+ }
1087+
1088+ // ── 36-bit SPO triple ───────────────────────────────────────────────────────
1089+
1090+ /// 36-bit SPO triple packed in u64. 12-bit subject + predicate + object.
1091+ #[ derive( Clone , Copy , PartialEq , Eq , Hash , Debug ) ]
1092+ pub struct SpoTriple { packed : u64 }
1093+
1094+ impl SpoTriple {
1095+ pub fn new ( subject : u16 , predicate : u16 , object : u16 ) -> Self {
1096+ debug_assert ! ( subject < 4096 && predicate < 4096 && object < 4096 ) ;
1097+ Self { packed : ( ( subject as u64 ) << 24 ) | ( ( predicate as u64 ) << 12 ) | object as u64 }
1098+ }
1099+ pub fn subject ( & self ) -> u16 { ( ( self . packed >> 24 ) & 0xFFF ) as u16 }
1100+ pub fn predicate ( & self ) -> u16 { ( ( self . packed >> 12 ) & 0xFFF ) as u16 }
1101+ pub fn object ( & self ) -> u16 { ( self . packed & 0xFFF ) as u16 }
1102+ }
1103+
1104+ // ── Prompt templates ────────────────────────────────────────────────────────
1105+
1106+ /// NSM explication system instruction.
1107+ pub const NSM_EXPLICATION_SYS_INST : & str = "You are a linguist specializing in semantic analysis using the Natural Semantic Metalanguage (NSM) approach. NSM reduces lexicons to universal semantic primes. Paraphrase the word's meaning using NSM primes." ;
1108+
1109+ /// Recovery prompt: predict masked word.
1110+ pub const RECOVERY_PROMPT_SYS_INST : & str = "Read the passage with a missing word indicated by <UNK>. Predict the missing word. Output only your prediction." ;
1111+
1112+ /// Chat message for prompt construction.
1113+ #[ derive( Clone , Debug ) ]
1114+ pub struct ChatMessage { pub role : String , pub content : String }
1115+
1116+ /// Build explication prompt with optional few-shot.
1117+ pub fn build_explication_prompt ( word : & str , examples : & [ & str ] , few_shot : & [ ( String , String ) ] , max : Option < usize > ) -> Vec < ChatMessage > {
1118+ let mut msgs = vec ! [ ChatMessage { role: "system" . into( ) , content: NSM_EXPLICATION_SYS_INST . into( ) } ] ;
1119+ for ( u, a) in & few_shot[ ..max. unwrap_or ( few_shot. len ( ) ) . min ( few_shot. len ( ) ) ] {
1120+ msgs. push ( ChatMessage { role : "user" . into ( ) , content : u. clone ( ) } ) ;
1121+ msgs. push ( ChatMessage { role : "assistant" . into ( ) , content : a. clone ( ) } ) ;
1122+ }
1123+ msgs. push ( ChatMessage { role : "user" . into ( ) , content : format ! ( "Word: {word}\n Examples:\n {}\n Paraphrase:" , examples. join( "\n \n " ) ) } ) ;
1124+ msgs
1125+ }
1126+
1127+ /// Build recovery prompt for substitutability testing.
1128+ pub fn build_recover_prompt ( ambig : & AmbiguousExample , exp : Option < & Explication > ) -> Vec < ChatMessage > {
1129+ let user = match exp {
1130+ Some ( e) => format ! ( "Passage: {}\n Paraphrase:\n {}\n Missing word:" , ambig. text, e. text) ,
1131+ None => format ! ( "Passage: {}\n Missing Word:" , ambig. text) ,
1132+ } ;
1133+ vec ! [
1134+ ChatMessage { role: "system" . into( ) , content: RECOVERY_PROMPT_SYS_INST . into( ) } ,
1135+ ChatMessage { role: "user" . into( ) , content: user } ,
1136+ ]
1137+ }
1138+
1139+ // ── Tests ───────────────────────────────────────────────────────────────────
1140+
1141+ #[ cfg( test) ]
1142+ mod eval_tests {
1143+ use super :: * ;
1144+
1145+ #[ test]
1146+ fn test_is_nsm_prime ( ) {
1147+ assert ! ( is_nsm_prime( "think" ) ) ;
1148+ assert ! ( is_nsm_prime( "THINK" ) ) ;
1149+ assert ! ( !is_nsm_prime( "journalism" ) ) ;
1150+ }
1151+
1152+ #[ test]
1153+ fn test_is_stop_word ( ) {
1154+ assert ! ( is_stop_word( "the" ) ) ;
1155+ assert ! ( !is_stop_word( "think" ) ) ; // NSM prime, not stopword
1156+ }
1157+
1158+ #[ test]
1159+ fn test_explication_legality ( ) {
1160+ let mut exp = Explication :: new ( "someone can feel something good because of this" ) ;
1161+ exp. legality_score ( "happy" ) ;
1162+ assert ! ( exp. primes_ratio > 0.5 ) ;
1163+ assert ! ( !exp. uses_original_word) ;
1164+ }
1165+
1166+ #[ test]
1167+ fn test_explication_circularity ( ) {
1168+ let mut exp = Explication :: new ( "feeling happy about something" ) ;
1169+ exp. legality_score ( "happy" ) ;
1170+ assert ! ( exp. uses_original_word) ;
1171+ }
1172+
1173+ #[ test]
1174+ fn test_explication_averages ( ) {
1175+ let mut exp = Explication :: new ( "test" ) ;
1176+ exp. primes_ratio = 0.6 ;
1177+ exp. molecules_ratio = 0.1 ;
1178+ let mut s = SubstitutabilityScore :: new ( "g" ) ;
1179+ s. adj_score = 5.0 ;
1180+ exp. sub_scores . push ( s) ;
1181+ exp. calculate_averages ( ) ;
1182+ assert ! ( exp. total_score > 0.0 ) ;
1183+ }
1184+
1185+ #[ test]
1186+ fn test_truncated ( ) {
1187+ let exp = Explication :: new ( "line one\n line two\n line three" ) ;
1188+ assert_eq ! ( exp. get_truncated( 2 ) . len( ) , 2 ) ;
1189+ }
1190+
1191+ #[ test]
1192+ fn test_ambiguous_truncated ( ) {
1193+ let a = AmbiguousExample :: new ( "The cat sat. The <UNK> was red. It was sunny." ) ;
1194+ let t = a. get_truncated ( 2 ) ;
1195+ assert ! ( !t. is_empty( ) ) ;
1196+ }
1197+
1198+ #[ test]
1199+ fn test_model_result ( ) {
1200+ let mut r = ModelResult :: new ( "m" ) ;
1201+ let mut e = Explication :: new ( "t" ) ;
1202+ e. primes_ratio = 0.5 ;
1203+ e. total_score = 8.0 ;
1204+ r. explications . push ( e) ;
1205+ r. calculate_averages ( ) ;
1206+ assert ! ( ( r. avg_total_score - 8.0 ) . abs( ) < 0.01 ) ;
1207+ }
1208+
1209+ #[ test]
1210+ fn test_spo_triple ( ) {
1211+ let t = SpoTriple :: new ( 671 , 2943 , 95 ) ;
1212+ assert_eq ! ( t. subject( ) , 671 ) ;
1213+ assert_eq ! ( t. predicate( ) , 2943 ) ;
1214+ assert_eq ! ( t. object( ) , 95 ) ;
1215+ }
1216+
1217+ #[ test]
1218+ fn test_cam_codes_load ( ) {
1219+ let bytes = vec ! [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 ] ;
1220+ let codes = load_cam_codes ( & bytes) ;
1221+ assert_eq ! ( codes. len( ) , 2 ) ;
1222+ assert_eq ! ( codes[ 0 ] , [ 1 , 2 , 3 , 4 , 5 , 6 ] ) ;
1223+ }
1224+
1225+ #[ test]
1226+ fn test_prompt_building ( ) {
1227+ let msgs = build_explication_prompt ( "happy" , & [ "I am happy" ] , & [ ] , None ) ;
1228+ assert_eq ! ( msgs. len( ) , 2 ) ;
1229+ assert ! ( msgs[ 1 ] . content. contains( "happy" ) ) ;
1230+ }
1231+ }
0 commit comments