11# SPDX-License-Identifier: MIT
2- # Copyright (c) 2022 The Pybricks Authors
2+ # Copyright (c) 2022-2026 The Pybricks Authors
33
44
55import contextlib
66import os
77import struct
88import sys
99from tempfile import TemporaryDirectory
10+ from typing import Any
1011
1112import pytest
1213
1617if sys .version_info < (3 , 11 ):
1718 from contextlib import AbstractContextManager
1819
19- class chdir (AbstractContextManager ):
20+ class chdir (AbstractContextManager [ None ] ):
2021 """Non thread-safe context manager to change the current working directory."""
2122
22- def __init__ (self , path ) :
23+ def __init__ (self , path : str ) -> None :
2324 self .path = path
24- self ._old_cwd = []
25+ self ._old_cwd : list [ str ] = []
2526
26- def __enter__ (self ):
27+ def __enter__ (self ) -> None :
2728 self ._old_cwd .append (os .getcwd ())
2829 os .chdir (self .path )
2930
30- def __exit__ (self , * excinfo ) :
31+ def __exit__ (self , * excinfo : Any ) -> None :
3132 os .chdir (self ._old_cwd .pop ())
3233
3334 setattr (contextlib , "chdir" , chdir )
@@ -73,6 +74,8 @@ async def test_compile_multi_file(abi: int):
7374 "import test1\n " ,
7475 "from test2 import thing2\n " ,
7576 "from nested.test3 import thing3\n " ,
77+ "from test4 import thing4\n " ,
78+ "from nested.test5 import thing5\n " ,
7679 ]
7780 )
7881
@@ -100,6 +103,24 @@ async def test_compile_multi_file(abi: int):
100103 ) as f3 :
101104 f3 .write ("thing3 = 'thing3'\n " )
102105
106+ # test4 and test5 are to test package modules with non-empty __init__.py
107+
108+ os .mkdir ("test4" )
109+
110+ with open (
111+ os .path .join (temp_dir , "test4" , "__init__.py" ), "w" , encoding = "utf-8"
112+ ) as f4 :
113+ f4 .write ("thing4 = 'thing4'\n " )
114+
115+ os .mkdir (os .path .join ("nested" , "test5" ))
116+
117+ with open (
118+ os .path .join (temp_dir , "nested" , "test5" , "__init__.py" ),
119+ "w" ,
120+ encoding = "utf-8" ,
121+ ) as f5 :
122+ f5 .write ("thing5 = 'thing5'\n " )
123+
103124 multi_mpy = await compile_multi_file ("test.py" , abi )
104125 pos = 0
105126
@@ -130,13 +151,21 @@ def unpack_mpy(data: bytes) -> tuple[bytes, bytes]:
130151 names .add (name2 .decode ())
131152 name3 , mpy3 = unpack_mpy (multi_mpy )
132153 names .add (name3 .decode ())
154+
133155 if uses_module_finder :
134156 # ModuleFinder requires __init__.py.
135157 name4 , mpy4 = unpack_mpy (multi_mpy )
136158 names .add (name4 .decode ())
159+
137160 name5 , mpy5 = unpack_mpy (multi_mpy )
138161 names .add (name5 .decode ())
139162
163+ name6 , mpy6 = unpack_mpy (multi_mpy )
164+ names .add (name6 .decode ())
165+
166+ name7 , mpy7 = unpack_mpy (multi_mpy )
167+ names .add (name7 .decode ())
168+
140169 assert pos == len (multi_mpy )
141170
142171 # It is important that the main module is first.
@@ -150,9 +179,9 @@ def unpack_mpy(data: bytes) -> tuple[bytes, bytes]:
150179 assert "nested.test3" in names
151180
152181 if uses_module_finder :
153- assert len (names ) == 4
182+ assert len (names ) == 6
154183 else :
155- assert len (names ) == 3
184+ assert len (names ) == 5
156185
157186 def check_mpy (mpy : bytes ) -> None :
158187 magic , abi_ver , flags , int_bits = struct .unpack_from ("<BBBB" , mpy )
@@ -168,3 +197,5 @@ def check_mpy(mpy: bytes) -> None:
168197 if uses_module_finder :
169198 check_mpy (mpy4 ) # pyright: ignore[reportPossiblyUnboundVariable]
170199 check_mpy (mpy5 )
200+ check_mpy (mpy6 )
201+ check_mpy (mpy7 )
0 commit comments