Skip to content

Commit 84a2bf6

Browse files
Address feedback from PR #241 (#244)
I added a method called `Token.with_source` to simplify the repeated logic in the source extent tests, but if it's unnecessary I can easily revert those commits. Edit: Linking the PR for easy future reference: #241.
1 parent 5aebe04 commit 84a2bf6

2 files changed

Lines changed: 101 additions & 134 deletions

File tree

scrapscript.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import argparse
44
import base64
55
import code
6+
import copy
67
import dataclasses
78
import enum
89
import functools
@@ -15,7 +16,6 @@
1516
import urllib.request
1617
from dataclasses import dataclass
1718
from enum import auto
18-
from functools import reduce
1919
from types import ModuleType
2020
from typing import Any, Callable, Dict, Generator, Iterator, Mapping, Optional, Set, Tuple, Union
2121

@@ -59,6 +59,10 @@ def join_source_extents(
5959
class Token:
6060
source_extent: SourceExtent = dataclasses.field(default_factory=SourceExtent, init=False, compare=False)
6161

62+
def with_source(self, source_extent: SourceExtent) -> Token:
63+
self.source_extent = source_extent
64+
return self
65+
6266

6367
@dataclass(eq=True)
6468
class IntLit(Token):
@@ -213,18 +217,7 @@ def peek_char(self) -> str:
213217

214218
def make_token(self, cls: type, *args: Any) -> Token:
215219
result: Token = cls(*args)
216-
217-
# Set start of token's source extent
218-
result.source_extent.start.lineno = self.current_token_source_extent.start.lineno
219-
result.source_extent.start.colno = self.current_token_source_extent.start.colno
220-
result.source_extent.start.byteno = self.current_token_source_extent.start.byteno
221-
222-
# Set end of token's source extent
223-
result.source_extent.end.colno = self.current_token_source_extent.end.colno
224-
result.source_extent.end.lineno = self.current_token_source_extent.end.lineno
225-
result.source_extent.end.byteno = self.current_token_source_extent.end.byteno
226-
227-
return result
220+
return result.with_source(copy.deepcopy(self.current_token_source_extent))
228221

229222
def read_tokens(self) -> Generator[Token, None, None]:
230223
while (token := self.read_token()) and not isinstance(token, EOF):
@@ -568,6 +561,7 @@ def parse_unary(tokens: Peekable, p: float) -> "Object":
568561
MatchCase, pipe_source_extent.coalesce(expr.source_extent), expr.arg, expr.body
569562
)
570563
cases = [match_case]
564+
match_function_source_extent = match_case.source_extent
571565
while True:
572566
try:
573567
if tokens.peek() != Operator("|"):
@@ -582,13 +576,10 @@ def parse_unary(tokens: Peekable, p: float) -> "Object":
582576
MatchCase, pipe_source_extent.coalesce(expr.source_extent), expr.arg, expr.body
583577
)
584578
cases.append(match_case)
585-
cases_source_extents = [case_branch.source_extent for case_branch in cases]
579+
match_function_source_extent = join_source_extents(match_function_source_extent, match_case.source_extent)
586580
return make_source_annotated_object(
587581
MatchFunction,
588-
reduce(
589-
join_source_extents,
590-
cases_source_extents,
591-
),
582+
match_function_source_extent,
592583
cases,
593584
)
594585
elif isinstance(token, LeftParen):

scrapscript_tests.py

Lines changed: 92 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import base64
21
import unittest
32
import re
43
from typing import Optional
@@ -1201,172 +1200,149 @@ def test_apply_with_variant_args(self) -> None:
12011200
self.assertEqual(ast, Apply(Apply(Var("f"), TRUE), FALSE))
12021201

12031202
def test_parse_int_preserves_source_extent(self) -> None:
1204-
int_lit = IntLit(1)
12051203
source_extent = SourceExtent(
12061204
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0)
12071205
)
1208-
int_lit.source_extent = source_extent
1209-
int_ast = make_source_annotated_object(Int, source_extent, 1)
1210-
self.assertTrue(parse(Peekable(iter([int_lit]))).source_extent == int_ast.source_extent)
1206+
int_lit = IntLit(1).with_source(source_extent)
1207+
self.assertEqual(parse(Peekable(iter([int_lit]))).source_extent, source_extent)
12111208

12121209
def test_parse_float_preserves_source_extent(self) -> None:
1213-
float_lit = FloatLit(3.2)
12141210
source_extent = SourceExtent(
12151211
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=3, byteno=2)
12161212
)
1217-
float_lit.source_extent = source_extent
1218-
float_ast = make_source_annotated_object(Float, source_extent, 3.2)
1219-
self.assertTrue(parse(Peekable(iter([float_lit]))).source_extent == float_ast.source_extent)
1213+
float_lit = FloatLit(3.2).with_source(source_extent)
1214+
self.assertEqual(parse(Peekable(iter([float_lit]))).source_extent, source_extent)
12201215

12211216
def test_parse_string_preserves_source_extent(self) -> None:
1222-
string_lit = StringLit("Hello")
12231217
source_extent = SourceExtent(
12241218
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=7, byteno=6)
12251219
)
1226-
string_lit.source_extent = source_extent
1227-
string_ast = make_source_annotated_object(String, source_extent, "Hello")
1228-
self.assertTrue(parse(Peekable(iter([string_lit]))).source_extent == string_ast.source_extent)
1220+
string_lit = StringLit("Hello").with_source(source_extent)
1221+
self.assertEqual(parse(Peekable(iter([string_lit]))).source_extent, source_extent)
12291222

12301223
def test_parse_bytes_preserves_source_extent(self) -> None:
1231-
bytes_lit = BytesLit("QUJD", 64)
12321224
source_extent = SourceExtent(
12331225
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=9, byteno=8)
12341226
)
1235-
bytes_lit.source_extent = source_extent
1236-
bytes_ast = make_source_annotated_object(Bytes, source_extent, base64.b64decode("QUJD"))
1237-
self.assertTrue(parse(Peekable(iter([bytes_lit]))).source_extent == bytes_ast.source_extent)
1227+
bytes_lit = BytesLit("QUJD", 64).with_source(source_extent)
1228+
self.assertEqual(parse(Peekable(iter([bytes_lit]))).source_extent, source_extent)
12381229

12391230
def test_parse_var_preserves_source_extent(self) -> None:
1240-
var = Name("x")
12411231
source_extent = SourceExtent(
12421232
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0)
12431233
)
1244-
var.source_extent = source_extent
1245-
var_ast = make_source_annotated_object(Var, source_extent, "x")
1246-
self.assertTrue(parse(Peekable(iter([var]))).source_extent == var_ast.source_extent)
1234+
var = Name("x").with_source(source_extent)
1235+
self.assertEqual(parse(Peekable(iter([var]))).source_extent, source_extent)
12471236

12481237
def test_parse_hole_preserves_source_extent(self) -> None:
1249-
left_paren = LeftParen()
1250-
left_paren.source_extent = SourceExtent(
1251-
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0)
1252-
)
1253-
right_paren = RightParen()
1254-
right_paren.source_extent = SourceExtent(
1255-
start=SourceLocation(lineno=1, colno=2, byteno=1), end=SourceLocation(lineno=1, colno=2, byteno=1)
1238+
left_paren = LeftParen().with_source(
1239+
SourceExtent(
1240+
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0)
1241+
)
12561242
)
1257-
hole = make_source_annotated_object(
1258-
Hole,
1243+
right_paren = RightParen().with_source(
12591244
SourceExtent(
1260-
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=2, byteno=1)
1261-
),
1245+
start=SourceLocation(lineno=1, colno=2, byteno=1), end=SourceLocation(lineno=1, colno=2, byteno=1)
1246+
)
1247+
)
1248+
hole_source_extent = SourceExtent(
1249+
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=2, byteno=1)
12621250
)
1263-
self.assertTrue(parse(Peekable(iter([left_paren, right_paren]))).source_extent == hole.source_extent)
1251+
self.assertEqual(parse(Peekable(iter([left_paren, right_paren]))).source_extent, hole_source_extent)
12641252

12651253
def test_parenthesized_expression_preserves_source_extent(self) -> None:
1266-
left_paren = LeftParen()
1267-
left_paren.source_extent = SourceExtent(
1268-
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0)
1269-
)
1270-
int_lit = IntLit(1)
1271-
int_lit.source_extent = SourceExtent(
1272-
start=SourceLocation(lineno=1, colno=2, byteno=1), end=SourceLocation(lineno=1, colno=2, byteno=1)
1254+
left_paren = LeftParen().with_source(
1255+
SourceExtent(
1256+
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0)
1257+
)
12731258
)
1274-
right_paren = RightParen()
1275-
right_paren.source_extent = SourceExtent(
1276-
start=SourceLocation(lineno=1, colno=3, byteno=2), end=SourceLocation(lineno=1, colno=3, byteno=2)
1259+
int_lit = IntLit(1).with_source(
1260+
SourceExtent(
1261+
start=SourceLocation(lineno=1, colno=2, byteno=1), end=SourceLocation(lineno=1, colno=2, byteno=1)
1262+
)
12771263
)
1278-
parenthesized_int_lit = make_source_annotated_object(
1279-
Int,
1264+
right_paren = RightParen().with_source(
12801265
SourceExtent(
1281-
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=3, byteno=2)
1282-
),
1283-
1,
1266+
start=SourceLocation(lineno=1, colno=3, byteno=2), end=SourceLocation(lineno=1, colno=3, byteno=2)
1267+
)
12841268
)
1285-
self.assertTrue(
1286-
parse(Peekable(iter([left_paren, int_lit, right_paren]))).source_extent
1287-
== parenthesized_int_lit.source_extent
1269+
parenthesized_int_lit_source_extent = SourceExtent(
1270+
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=3, byteno=2)
1271+
)
1272+
self.assertEqual(
1273+
parse(Peekable(iter([left_paren, int_lit, right_paren]))).source_extent, parenthesized_int_lit_source_extent
12881274
)
12891275

12901276
def test_parse_spread_preserves_source_extent(self) -> None:
1291-
ellipsis = Operator("...")
1292-
ellipsis.source_extent = SourceExtent(
1293-
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=3, byteno=2)
1294-
)
1295-
name = Name("x")
1296-
name.source_extent = SourceExtent(
1297-
start=SourceLocation(lineno=1, colno=4, byteno=3), end=SourceLocation(lineno=1, colno=4, byteno=3)
1277+
ellipsis = Operator("...").with_source(
1278+
SourceExtent(
1279+
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=3, byteno=2)
1280+
)
12981281
)
1299-
spread = make_source_annotated_object(
1300-
Spread,
1282+
name = Name("x").with_source(
13011283
SourceExtent(
1302-
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=4, byteno=3)
1303-
),
1304-
"x",
1284+
start=SourceLocation(lineno=1, colno=4, byteno=3), end=SourceLocation(lineno=1, colno=4, byteno=3)
1285+
)
13051286
)
1306-
self.assertTrue(parse(Peekable(iter([ellipsis, name]))).source_extent == spread.source_extent)
1287+
spread_source_extent = SourceExtent(
1288+
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=4, byteno=3)
1289+
)
1290+
self.assertEqual(parse(Peekable(iter([ellipsis, name]))).source_extent, spread_source_extent)
13071291

13081292
def test_parse_binop_preserves_source_extent(self) -> None:
1309-
first_addend = IntLit(1)
1310-
first_addend.source_extent = SourceExtent(
1311-
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0)
1312-
)
1313-
operator = Operator("+")
1314-
operator.source_extent = SourceExtent(
1315-
start=SourceLocation(lineno=1, colno=3, byteno=2), end=SourceLocation(lineno=1, colno=3, byteno=2)
1293+
first_addend = IntLit(1).with_source(
1294+
SourceExtent(
1295+
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0)
1296+
)
13161297
)
1317-
second_addend = IntLit(2)
1318-
second_addend.source_extent = SourceExtent(
1319-
start=SourceLocation(lineno=2, colno=5, byteno=4), end=SourceLocation(lineno=2, colno=5, byteno=4)
1298+
operator = Operator("+").with_source(
1299+
SourceExtent(
1300+
start=SourceLocation(lineno=1, colno=3, byteno=2), end=SourceLocation(lineno=1, colno=3, byteno=2)
1301+
)
13201302
)
1321-
first_addend_ast = make_source_annotated_object(Int, first_addend.source_extent, 1)
1322-
operator_ast = BinopKind.ADD
1323-
second_addend_ast = make_source_annotated_object(Int, second_addend.source_extent, 2)
1324-
binop = make_source_annotated_object(
1325-
Binop,
1303+
second_addend = IntLit(2).with_source(
13261304
SourceExtent(
1327-
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=2, colno=5, byteno=4)
1328-
),
1329-
operator_ast,
1330-
first_addend_ast,
1331-
second_addend_ast,
1305+
start=SourceLocation(lineno=2, colno=5, byteno=4), end=SourceLocation(lineno=2, colno=5, byteno=4)
1306+
)
13321307
)
1333-
self.assertTrue(
1334-
parse(Peekable(iter([first_addend, operator, second_addend]))).source_extent == binop.source_extent
1308+
binop_source_extent = SourceExtent(
1309+
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=2, colno=5, byteno=4)
1310+
)
1311+
self.assertEqual(
1312+
parse(Peekable(iter([first_addend, operator, second_addend]))).source_extent, binop_source_extent
13351313
)
13361314

13371315
def test_parse_list_preserves_source_extent(self) -> None:
1338-
left_bracket = LeftBracket()
1339-
left_bracket.source_extent = SourceExtent(
1340-
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0)
1341-
)
1342-
one = IntLit(1)
1343-
one.source_extent = SourceExtent(
1344-
start=SourceLocation(lineno=1, colno=2, byteno=1), end=SourceLocation(lineno=1, colno=2, byteno=1)
1316+
left_bracket = LeftBracket().with_source(
1317+
SourceExtent(
1318+
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=1, byteno=0)
1319+
)
13451320
)
1346-
comma = Operator(",")
1347-
comma.source_extent = SourceExtent(
1348-
start=SourceLocation(lineno=1, colno=3, byteno=2), end=SourceLocation(lineno=1, colno=3, byteno=2)
1321+
one = IntLit(1).with_source(
1322+
SourceExtent(
1323+
start=SourceLocation(lineno=1, colno=2, byteno=1), end=SourceLocation(lineno=1, colno=2, byteno=1)
1324+
)
13491325
)
1350-
two = IntLit(2)
1351-
two.source_extent = SourceExtent(
1352-
start=SourceLocation(lineno=1, colno=5, byteno=4), end=SourceLocation(lineno=1, colno=5, byteno=4)
1326+
comma = Operator(",").with_source(
1327+
SourceExtent(
1328+
start=SourceLocation(lineno=1, colno=3, byteno=2), end=SourceLocation(lineno=1, colno=3, byteno=2)
1329+
)
13531330
)
1354-
right_bracket = RightBracket()
1355-
right_bracket.source_extent = SourceExtent(
1356-
start=SourceLocation(lineno=1, colno=6, byteno=5), end=SourceLocation(lineno=1, colno=6, byteno=5)
1331+
two = IntLit(2).with_source(
1332+
SourceExtent(
1333+
start=SourceLocation(lineno=1, colno=5, byteno=4), end=SourceLocation(lineno=1, colno=5, byteno=4)
1334+
)
13571335
)
1358-
one_ast = make_source_annotated_object(Int, one.source_extent, 1)
1359-
two_ast = make_source_annotated_object(Int, two.source_extent, 2)
1360-
list_ast = make_source_annotated_object(
1361-
List,
1336+
right_bracket = RightBracket().with_source(
13621337
SourceExtent(
1363-
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=6, byteno=5)
1364-
),
1365-
[one_ast, two_ast],
1338+
start=SourceLocation(lineno=1, colno=6, byteno=5), end=SourceLocation(lineno=1, colno=6, byteno=5)
1339+
)
13661340
)
1367-
self.assertTrue(
1368-
parse(Peekable(iter([left_bracket, one, comma, two, right_bracket]))).source_extent
1369-
== list_ast.source_extent
1341+
list_source_extent = SourceExtent(
1342+
start=SourceLocation(lineno=1, colno=1, byteno=0), end=SourceLocation(lineno=1, colno=6, byteno=5)
1343+
)
1344+
self.assertEqual(
1345+
parse(Peekable(iter([left_bracket, one, comma, two, right_bracket]))).source_extent, list_source_extent
13701346
)
13711347

13721348

@@ -2620,7 +2596,7 @@ def test_eval_count_bits_function_preserves_source_extents(self) -> None:
26202596
match_function_three_source_extent,
26212597
]
26222598

2623-
self.assertTrue(outer_function.source_extent == outer_function_source_extent)
2599+
self.assertEqual(outer_function.source_extent, outer_function_source_extent)
26242600
self.assertTrue(
26252601
all(
26262602
match_function.source_extent == source_extent
@@ -2672,9 +2648,9 @@ def test_eval_collatz_function_preserves_source_extents(self) -> None:
26722648
start=SourceLocation(lineno=4, colno=34, byteno=84), end=SourceLocation(lineno=5, colno=79, byteno=205)
26732649
)
26742650

2675-
self.assertTrue(outer_function.source_extent == outer_function_source_extent)
2676-
self.assertTrue(arg.source_extent == arg_source_extent)
2677-
self.assertTrue(func.source_extent == func_source_extent)
2651+
self.assertEqual(outer_function.source_extent, outer_function_source_extent)
2652+
self.assertEqual(arg.source_extent, arg_source_extent)
2653+
self.assertEqual(func.source_extent, func_source_extent)
26782654

26792655

26802656
class ClosureOptimizeTests(unittest.TestCase):

0 commit comments

Comments
 (0)