-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathai_predict_language.py
More file actions
199 lines (179 loc) · 8.81 KB
/
ai_predict_language.py
File metadata and controls
199 lines (179 loc) · 8.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
#!/usr/bin/env python3
import argparse
import os
from pathlib import Path
import time
import warnings
from language_predictor import LanguagePredictor
from sklearn import metrics
import chardet
def warn(*args, **kwargs):
pass
warnings.warn = warn
def read_file(file_path: str) -> str:
"""Read file content with UTF-8 encoding, fallback to chardet-detected encoding if needed"""
try:
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
except UnicodeDecodeError:
# Try to detect encoding with chardet
try:
with open(file_path, 'rb') as f:
raw = f.read()
result = chardet.detect(raw)
encoding = result['encoding']
with open(file_path, 'r', encoding=encoding) as f:
return f.read()
except Exception:
print(f"Warning: Could not read {file_path} with detected encoding. Skipping.")
return None
def predict_file(predictor: LanguagePredictor, file_path: str, top_k: int = 1) -> dict:
"""Predict language for a single file"""
content = read_file(file_path)
if content is None:
return None
start_time = time.time()
predictions = predictor.predict(content, top_k=top_k)
prediction_time = time.time() - start_time
return {
'predictions': predictions,
'time': prediction_time
}
def save_results(lang_true: list, lang_pred: list, output_file: str = "result.txt"):
"""Save prediction results to file"""
result = metrics.confusion_matrix(lang_true, lang_pred)
report = metrics.classification_report(lang_true, lang_pred)
with open(output_file, "w") as resultfile:
resultfile.write(f"Predicted on {len(lang_pred)} files. Results are as follows:\n\n")
resultfile.write("Confusion Matrix:\n")
for row in result:
string = "\t".join(str(column) for column in row)
resultfile.write(string + "\n")
resultfile.write("\nClassification Report\n")
resultfile.write(report)
def predict_directory(predictor: LanguagePredictor, dir_path: str, top_k: int = 1):
"""Predict language for all files in a directory"""
total_files = 0
total_time = 0
lang_true = []
lang_pred = []
for root, _, files in os.walk(dir_path):
for fname in files:
fpath = os.path.join(root, fname)
result = predict_file(predictor, fpath, top_k=top_k)
if result:
total_files += 1
total_time += result['time']
print(f"\nFile: {fpath}")
print("Top predictions:")
if isinstance(result['predictions'], list):
pred_lang = result['predictions'][0][0]
for lang, prob in result['predictions']:
print(f" {lang}: {prob:.2%}")
else:
pred_lang = result['predictions']
print(f" {pred_lang}")
if args.validate:
ext = fpath.split('.')[-1] # Remove leading dot
if ext in predictor.languages:
lang_true.append(ext)
lang_pred.append(pred_lang)
if total_files > 0:
print(f"\nProcessed {total_files} files")
print(f"Average prediction time: {total_time/total_files:.3f}s")
print(f"Files per second: {total_files/total_time:.1f}")
if args.validate and lang_true:
print("\nValidation Results:")
save_results(lang_true, lang_pred, args.output)
def main():
parser = argparse.ArgumentParser(
description='Predict programming language of files',
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--file', type=str, help='Path to a single file to predict')
group.add_argument('--dir', type=str, help='Path to a directory to predict all files in')
parser.add_argument('--model-dir', default='models', help='Directory containing trained model')
parser.add_argument('--top-k', type=int, default=1, help='Number of top predictions to show')
parser.add_argument('--batch', action='store_true', help='Process multiple files')
parser.add_argument('--legacy', action='store_true', help='Use legacy TensorFlow model')
parser.add_argument('--train', action='store_true', help='Train model before prediction')
parser.add_argument('--train-dir', default='FileTypeData', help='Directory containing training data')
parser.add_argument('--output', default='result.txt', help='Output file for batch results')
parser.add_argument('--validate', action='store_true', help='Validate predictions against file extensions')
args = parser.parse_args()
# Initialize predictor
try:
if args.train:
print("Training model...")
predictor = LanguagePredictor(args.model_dir, use_legacy=args.legacy)
train_metrics = predictor.train(args.train_dir)
predictor.save()
print("\nTraining Metrics:")
print(f"Training time: {train_metrics['training_time']:.1f}s")
print(f"Training samples: {train_metrics['train_size']}")
print(f"Test samples: {train_metrics['test_size']}")
print("\nClassification Report:")
for label, scores in train_metrics['classification_report'].items():
if isinstance(scores, dict):
print(f"\n{label}:")
for metric, value in scores.items():
print(f" {metric}: {value:.3f}")
else:
predictor = LanguagePredictor.load(args.model_dir, use_legacy=args.legacy)
except Exception as e:
print(f"Error: {str(e)}")
return
if args.file:
result = predict_file(predictor, args.file, top_k=args.top_k)
if result:
print(f"\nFile: {args.file}")
print("Top predictions:")
if isinstance(result['predictions'], list):
for lang, prob in result['predictions']:
print(f" {lang}: {prob:.2%}")
pred_lang = result['predictions'][0][0]
else:
print(f" {result['predictions']}")
pred_lang = result['predictions']
print(f"Prediction time: {result['time']:.3f}s")
# Generate confusion matrix and classification report for single file
ext = os.path.splitext(args.file)[1][1:] # get extension without dot
# Try to map extension to language name using predictor.languages, else use predicted label
true_lang = predictor.languages.get(ext, pred_lang) if hasattr(predictor, 'languages') else pred_lang
y_true = [true_lang]
y_pred = [pred_lang]
print("\nPredicted on 1 files. Results are as follows:\n")
print("Confusion Matrix:")
print(metrics.confusion_matrix(y_true, y_pred, labels=[true_lang]))
print("\nClassification Report")
print(metrics.classification_report(y_true, y_pred, labels=[true_lang]))
print(f"\nPrediction confidence: {result['predictions'][0][1]*100:.2f}%")
print(f"Prediction time: {result['time']:.3f}s")
# Heuristic overrides for YAML and C++
content = read_file(args.file)
ext = os.path.splitext(args.file)[1].lower()
heuristic_applied = False
# YAML heuristic
if ext in ['.yaml', '.yml'] and content is not None:
if (content.strip().startswith('---') or ':' in content) and not any(x in content for x in ['{', '}', ';']):
pred_lang = 'YAML'
result['predictions'] = [(pred_lang, 1.0)]
heuristic_applied = True
# C++ heuristic
cpp_exts = ['.cpp', '.cc', '.cxx', '.hpp', '.h', '.hxx']
if ext in cpp_exts and content is not None:
if any(token in content for token in ['#include', 'std::', '::', 'template<']):
pred_lang = 'C++'
result['predictions'] = [(pred_lang, 1.0)]
heuristic_applied = True
if heuristic_applied:
print(f"[Heuristic applied: Overriding prediction to {pred_lang}]")
else:
print(f"Could not predict language for {args.file}")
elif args.dir:
predict_directory(predictor, args.dir, top_k=args.top_k)
else:
print("Error: Invalid path or missing --file or --dir flag")
if __name__ == '__main__':
main()