Skip to content

Commit 0b5d0fa

Browse files
committed
Check for invalid characters in UnknownUnit
Current the only invalid characters are Space, /, and ^. I've also refactored the argument parsing to remove duplication between the numerator and denominator.
1 parent 84db638 commit 0b5d0fa

3 files changed

Lines changed: 106 additions & 61 deletions

File tree

sasdata/quantities/_units_base.py

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from dataclasses import dataclass
33
from fractions import Fraction
44
from typing import Self
5+
import re
56

67
import numpy as np
78
from unicode_superscript import int_as_unicode_superscript
@@ -326,40 +327,57 @@ class UnknownUnit(NamedUnit):
326327

327328
def __init__(self,
328329
numerator: str | list[str] | dict[str, int],
329-
denominator: None | list[str] | dict[str, int]= None):
330-
match numerator:
331-
case str():
332-
self._numerator = {numerator: 1}
333-
case list():
334-
self._numerator = {}
335-
for key in numerator:
336-
if key in self._numerator:
337-
self._numerator[key] += 1
338-
else:
339-
self._numerator[key] = 1
340-
case dict():
341-
self._numerator = numerator
342-
case _:
343-
raise TypeError
344-
match denominator:
330+
denominator: None | list[str] | dict[str, int] = None):
331+
if numerator is None:
332+
return TypeError
333+
self._numerator = UnknownUnit._parse_arg(numerator)
334+
self._denominator = UnknownUnit._parse_arg(denominator)
335+
self._unit = NamedUnit(1, Dimensions(), "") # Unitless
336+
337+
super().__init__(si_scaling_factor=1, dimensions=self._unit.dimensions, symbol=self._name())
338+
339+
@staticmethod
340+
def _parse_arg(arg: str | list[str] | dict[str, int]):
341+
"""Parse the different possibilities for constructor arguments
342+
343+
Both the numerator and the denominator could be a string, a
344+
list of strings, or a dict. Parse any of these values into a
345+
dictionary of names and powers.
346+
347+
"""
348+
match arg:
345349
case None:
346-
self._denominator = {}
350+
return {}
347351
case str():
348-
self._denominator = {denominator: 1}
352+
return {UnknownUnit._valid_name(arg): 1}
349353
case list():
350-
self._denominator = {}
351-
for key in denominator:
352-
if key in self._denominator:
353-
self._denominator[key] += 1
354+
result = {}
355+
for key in arg:
356+
if key in result:
357+
result[key] += 1
354358
else:
355-
self._denominator[key] = 1
359+
UnknownUnit._valid_name(key)
360+
result[key] = 1
361+
return result
356362
case dict():
357-
self._denominator = denominator
363+
for key in arg:
364+
UnknownUnit._valid_name(key)
365+
return arg
358366
case _:
359367
raise TypeError
360-
self._unit = NamedUnit(1, Dimensions(), "") # Unitless
361368

362-
super().__init__(si_scaling_factor=1, dimensions=self._unit.dimensions, symbol=self._name())
369+
@staticmethod
370+
def _valid_name(name: str) -> str:
371+
"""Confirms that the name of a unit is appropriate
372+
373+
This mostly confirms that the unit does not contain math
374+
operators that would act on other units, like / or ^
375+
"""
376+
377+
if re.search(r"[*/^\s]", name):
378+
raise RuntimeError(f'Unit name "{name}" contains invalid characters (*, /, ^, or whitespace)')
379+
380+
return name
363381

364382
def _name(self):
365383
num = []
@@ -482,7 +500,6 @@ def __pow__(self, power: int):
482500
case _:
483501
return NotImplemented
484502

485-
486503
def equivalent(self: Self, other: "Unit"):
487504
match other:
488505
case UnknownUnit():

sasdata/quantities/units.py

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
from dataclasses import dataclass
8787
from fractions import Fraction
8888
from typing import Self
89+
import re
8990

9091
import numpy as np
9192

@@ -393,13 +394,12 @@ def __eq__(self, other):
393394
case _:
394395
return False
395396

396-
397397
def startswith(self, prefix: str) -> bool:
398398
"""Check if any representation of the unit begins with the prefix string"""
399399
prefix = prefix.lower()
400400
return (self.name is not None and self.name.lower().startswith(prefix)) \
401-
or (self.ascii_symbol is not None and self.ascii_symbol.lower().startswith(prefix)) \
402-
or (self.symbol is not None and self.symbol.lower().startswith(prefix))
401+
or (self.ascii_symbol is not None and self.ascii_symbol.lower().startswith(prefix)) \
402+
or (self.symbol is not None and self.symbol.lower().startswith(prefix))
403403

404404

405405
class UnknownUnit(NamedUnit):
@@ -412,40 +412,57 @@ class UnknownUnit(NamedUnit):
412412

413413
def __init__(self,
414414
numerator: str | list[str] | dict[str, int],
415-
denominator: None | list[str] | dict[str, int]= None):
416-
match numerator:
417-
case str():
418-
self._numerator = {numerator: 1}
419-
case list():
420-
self._numerator = {}
421-
for key in numerator:
422-
if key in self._numerator:
423-
self._numerator[key] += 1
424-
else:
425-
self._numerator[key] = 1
426-
case dict():
427-
self._numerator = numerator
428-
case _:
429-
raise TypeError
430-
match denominator:
415+
denominator: None | list[str] | dict[str, int] = None):
416+
if numerator is None:
417+
return TypeError
418+
self._numerator = UnknownUnit._parse_arg(numerator)
419+
self._denominator = UnknownUnit._parse_arg(denominator)
420+
self._unit = NamedUnit(1, Dimensions(), "") # Unitless
421+
422+
super().__init__(si_scaling_factor=1, dimensions=self._unit.dimensions, symbol=self._name())
423+
424+
@staticmethod
425+
def _parse_arg(arg: str | list[str] | dict[str, int]):
426+
"""Parse the different possibilities for constructor arguments
427+
428+
Both the numerator and the denominator could be a string, a
429+
list of strings, or a dict. Parse any of these values into a
430+
dictionary of names and powers.
431+
432+
"""
433+
match arg:
431434
case None:
432-
self._denominator = {}
435+
return {}
433436
case str():
434-
self._denominator = {denominator: 1}
437+
return {UnknownUnit._valid_name(arg): 1}
435438
case list():
436-
self._denominator = {}
437-
for key in denominator:
438-
if key in self._denominator:
439-
self._denominator[key] += 1
439+
result = {}
440+
for key in arg:
441+
if key in result:
442+
result[key] += 1
440443
else:
441-
self._denominator[key] = 1
444+
UnknownUnit._valid_name(key)
445+
result[key] = 1
446+
return result
442447
case dict():
443-
self._denominator = denominator
448+
for key in arg:
449+
UnknownUnit._valid_name(key)
450+
return arg
444451
case _:
445452
raise TypeError
446-
self._unit = NamedUnit(1, Dimensions(), "") # Unitless
447453

448-
super().__init__(si_scaling_factor=1, dimensions=self._unit.dimensions, symbol=self._name())
454+
@staticmethod
455+
def _valid_name(name: str) -> str:
456+
"""Confirms that the name of a unit is appropriate
457+
458+
This mostly confirms that the unit does not contain math
459+
operators that would act on other units, like / or ^
460+
"""
461+
462+
if re.search(r"[*/^\s]", name):
463+
raise RuntimeError(f'Unit name "{name}" contains invalid characters (*, /, ^, or whitespace)')
464+
465+
return name
449466

450467
def _name(self):
451468
num = []
@@ -479,7 +496,6 @@ def __eq__(self, other):
479496
case Unit():
480497
return not self._numerator and not self._denominator and self._unit == other
481498

482-
483499
def __mul__(self: Self, other: "Unit"):
484500
match other:
485501
case UnknownUnit():
@@ -569,7 +585,6 @@ def __pow__(self, power: int):
569585
case _:
570586
return NotImplemented
571587

572-
573588
def equivalent(self: Self, other: "Unit"):
574589
match other:
575590
case UnknownUnit():

test/quantities/utest_units.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,24 @@ def test_unit_dissimilar(dissimilar_term):
9999

100100
def test_unit_names():
101101
pizza = UnknownUnit(["Pizza"])
102-
slice = UnknownUnit(["Slice"])
103-
pineapple = UnknownUnit(["Pineapple"])
104-
pie = UnknownUnit(["Pie"])
102+
slice = UnknownUnit("Slice")
103+
pineapple = UnknownUnit("Pineapple")
104+
pie = UnknownUnit("Pie")
105105
empty = UnknownUnit([])
106106

107+
with pytest.raises(RuntimeError):
108+
UnknownUnit("a/b")
109+
with pytest.raises(RuntimeError):
110+
UnknownUnit(["a^b"])
111+
with pytest.raises(RuntimeError):
112+
UnknownUnit({"a b": 1})
113+
with pytest.raises(RuntimeError):
114+
UnknownUnit("a", {"a*b": 1})
115+
with pytest.raises(RuntimeError):
116+
UnknownUnit("a", ["a^b"])
117+
with pytest.raises(RuntimeError):
118+
UnknownUnit("a", "a/b")
119+
107120
assert str(empty) == ""
108121

109122
assert str(pizza) == "Pizza"

0 commit comments

Comments
 (0)