Skip to content

Commit e7b3f54

Browse files
authored
Detect partial stub packages (#13)
Work towards #11 This is a `semanal_pass1` feature that would be good to access without full deserialization, since it is needed early during graph loading.
1 parent 615b73c commit e7b3f54

2 files changed

Lines changed: 20 additions & 5 deletions

File tree

src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ fn parse(
4747
platform: Option<String>,
4848
always_true: Option<Vec<String>>,
4949
always_false: Option<Vec<String>>,
50-
) -> PyResult<(Vec<u8>, Vec<PyObject>, Vec<PyObject>, Vec<u8>)> {
50+
) -> PyResult<(Vec<u8>, Vec<PyObject>, Vec<PyObject>, Vec<u8>, bool)> {
5151
// Get defaults from Python if not provided
5252
let python_version = match python_version {
5353
Some(v) => v,
@@ -63,7 +63,7 @@ fn parse(
6363
let always_false = always_false.unwrap_or_default();
6464

6565
let path = Path::new(&fnam);
66-
let (ast_bytes, syntax_errors, type_ignore_lines, import_bytes) = py
66+
let (ast_bytes, syntax_errors, type_ignore_lines, import_bytes, is_partial_package) = py
6767
.allow_threads(|| {
6868
serialize_ast::serialize_python_file(
6969
path,
@@ -102,7 +102,7 @@ fn parse(
102102
.collect();
103103
let py_type_ignores = py_type_ignores?;
104104

105-
Ok((ast_bytes, py_errors, py_type_ignores, import_bytes))
105+
Ok((ast_bytes, py_errors, py_type_ignores, import_bytes, is_partial_package))
106106
}
107107

108108
/// Get the default Python version from sys.version_info

src/serialize_ast.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,14 @@ pub(crate) fn serialize_python_file(
171171
file_path: &Path,
172172
skip_function_bodies: bool,
173173
options: Options,
174-
) -> Result<(Vec<u8>, Vec<SyntaxError>, Vec<(usize, Vec<String>)>, Vec<u8>)> {
174+
) -> Result<(Vec<u8>, Vec<SyntaxError>, Vec<(usize, Vec<String>)>, Vec<u8>, bool)> {
175175
let source_type = PySourceType::from(file_path);
176176
let source_text = std::fs::read_to_string(file_path)?;
177177
let line_index = LineIndex::from_source_text(&source_text);
178+
let is_stub_package = match file_path.file_name() {
179+
Some(file) => file.as_encoded_bytes() == b"__init__.pyi",
180+
_ => false,
181+
};
178182

179183
// Check if file is all ASCII and build per-line non-ASCII flags if needed
180184
let is_all_ascii = source_text.is_ascii();
@@ -221,6 +225,7 @@ pub(crate) fn serialize_python_file(
221225
options,
222226
current_unreachable: false,
223227
current_mypy_only: false,
228+
top_level_getattr: false,
224229
};
225230
parsed.syntax().serialize(&mut ser);
226231

@@ -233,7 +238,10 @@ pub(crate) fn serialize_python_file(
233238
Some(ser.lines_with_non_ascii),
234239
);
235240

236-
Ok((ser.bytes, syntax_errors, type_ignore_lines, import_bytes))
241+
// Return this directly to caller, so that it can check this without deserialization
242+
let is_partial_package = is_stub_package && ser.top_level_getattr;
243+
244+
Ok((ser.bytes, syntax_errors, type_ignore_lines, import_bytes, is_partial_package))
237245
}
238246

239247
// Bit flags for import statement metadata
@@ -279,6 +287,7 @@ struct Serializer<'a> {
279287
options: Options, // Reachability analysis options
280288
current_unreachable: bool, // Whether we're currently in an unreachable block
281289
current_mypy_only: bool, // Whether we're currently in a mypy-only block (e.g., if TYPE_CHECKING)
290+
top_level_getattr: bool, // Does module have top-level __getattr__() function
282291
}
283292

284293
impl<'a> Serializer<'a> {
@@ -864,6 +873,10 @@ impl Ser for ast::Stmt {
864873
true
865874
};
866875

876+
if !ser.in_class && !ser.in_function && f.name.as_str() == "__getattr__" {
877+
ser.top_level_getattr = true;
878+
};
879+
867880
// Body - mark that we're inside a function
868881
let was_in_function = ser.in_function;
869882
ser.in_function = true;
@@ -2580,6 +2593,7 @@ pub fn serialize_imports(
25802593
options: Options::default(),
25812594
current_unreachable: false,
25822595
current_mypy_only: false,
2596+
top_level_getattr: false,
25832597
};
25842598

25852599
// Write list of imports
@@ -2708,6 +2722,7 @@ mod tests {
27082722
options: Options::default(),
27092723
current_unreachable: false,
27102724
current_mypy_only: false,
2725+
top_level_getattr: false,
27112726
}
27122727
}
27132728

0 commit comments

Comments
 (0)