Skip to content

Commit f10a6e1

Browse files
committed
Add flatmap to Result
1 parent c606ef4 commit f10a6e1

2 files changed

Lines changed: 74 additions & 10 deletions

File tree

src/fieldenum/enums.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from . import Unit, Variant, fieldenum, unreachable
2828
from .exceptions import IncompatibleBoundError, UnwrapFailedError
2929

30-
__all__ = ["Option", "BoundResult", "Message", "Some", "Success", "Failed"]
30+
__all__ = ["Option", "BoundResult", "Message", "Some", "Success", "Failed", "Result", "Ok", "Err"]
3131

3232
_MISSING = object()
3333
type _ExceptionTypes = type[BaseException] | tuple[type[BaseException], ... ] | UnionType
@@ -347,15 +347,15 @@ def dump(self) -> tuple[E]: ...
347347
def __bool__(self) -> bool:
348348
return isinstance(self, Result.Ok)
349349

350-
@overload
351-
def unwrap(self) -> R: ...
352-
353350
@overload
354351
def unwrap(self, default: R) -> R: ...
355352

356353
@overload
357354
def unwrap[T](self, default: T) -> R | T: ...
358355

356+
@overload
357+
def unwrap(self) -> R: ...
358+
359359
def unwrap(self, default=_MISSING):
360360
match self:
361361
case Result.Ok(value):
@@ -387,7 +387,7 @@ def as_option(self) -> Option[R]:
387387
def exit(self, error_code: str | int | None = 1) -> NoReturn:
388388
sys.exit(0 if self else error_code)
389389

390-
def map[NewReturn](self, func: Callable[[R], NewReturn], bound: _ExceptionTypes, /) -> Result[NewReturn, E]:
390+
def map[NewReturn](self, func: Callable[[R], NewReturn], /, bound: _ExceptionTypes) -> Result[NewReturn, E]:
391391
match self:
392392
case Result.Ok(ok):
393393
try:
@@ -398,11 +398,32 @@ def map[NewReturn](self, func: Callable[[R], NewReturn], bound: _ExceptionTypes,
398398
else:
399399
raise
400400

401-
case Result.Err(error) as failed:
402-
if TYPE_CHECKING:
403-
return Result.Err[NewReturn, E](error)
401+
case Result.Err() as err:
402+
return err # type: ignore
403+
404+
case other:
405+
unreachable(other)
406+
407+
def flatmap[NewResult: Result](self, func: Callable[[R], NewResult], /, bound: _ExceptionTypes) -> NewResult:
408+
match self:
409+
case Result.Ok(value):
410+
try:
411+
result = func(value)
412+
except BaseException as exc:
413+
if isinstance(exc, bound):
414+
return Result.Err(exc) # type: ignore
415+
else:
416+
raise
417+
418+
if isinstance(result, Result):
419+
return result
404420
else:
405-
return failed
421+
raise TypeError(
422+
f"Expect Result but received {type(result).__name__!r}"
423+
)
424+
425+
case Result.Err() as err:
426+
return err # type: ignore
406427

407428
case other:
408429
unreachable(other)
@@ -663,5 +684,7 @@ def dump(self) -> tuple[()]: ...
663684

664685

665686
Some = Option.Some
687+
Ok = Result.Ok
688+
Err = Result.Err
666689
Success = BoundResult.Success
667690
Failed = BoundResult.Failed

tests/test_enums.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,52 @@
1+
from collections.abc import Callable
12
from typing import Any, assert_type
23
import pytest
34
from fieldenum import *
45
from fieldenum.enums import *
56
from fieldenum.exceptions import IncompatibleBoundError, UnwrapFailedError
67

78

8-
def test_option_spread_map():
9+
def test_result_maps():
10+
result = Result.Ok("hello")
11+
err = result.map(int, ValueError)
12+
assert type(err.error) == ValueError # type: ignore
13+
upper = result.map(str.upper, Exception)
14+
assert upper.unwrap() == "HELLO"
15+
with pytest.raises(ValueError):
16+
result.map(int, ArithmeticError)
17+
18+
def returns_result[T, R, E: BaseException](func: Callable[[T], R], bound: type[E]) -> Callable[[T], Result[R, E]]:
19+
def inner(value: T):
20+
try:
21+
return Ok(func(value))
22+
except bound as exc:
23+
return Err(exc)
24+
return inner
25+
26+
# flatmap with returning result
27+
result = Result.Ok("hello")
28+
err = result.flatmap(returns_result(int, ValueError), ValueError)
29+
assert type(err.error) == ValueError # type: ignore
30+
upper = result.flatmap(returns_result(str.upper, ValueError), Exception)
31+
assert upper.unwrap() == "HELLO"
32+
err = Exception()
33+
assert Result.Err(err).flatmap(returns_result(int, ValueError), ValueError).error is err # type: ignore
34+
with pytest.raises(ValueError):
35+
result.flatmap(returns_result(int, ArithmeticError), ArithmeticError) # type: ignore
36+
37+
# flatmap with raising
38+
err = result.flatmap(returns_result(int, ArithmeticError), ValueError)
39+
assert type(err.error) == ValueError # type: ignore
40+
upper = result.flatmap(returns_result(str.upper, ArithmeticError), Exception)
41+
assert upper.unwrap() == "HELLO"
42+
assert Result.Err(err).flatmap(returns_result(int, ArithmeticError), ValueError).error is err # type: ignore
43+
with pytest.raises(ValueError):
44+
result.flatmap(returns_result(int, ArithmeticError), ArithmeticError) # type: ignore
45+
with pytest.raises(TypeError):
46+
result.flatmap(str.upper, ArithmeticError) # type: ignore
47+
48+
49+
def test_option_flatmap():
950
opt = Option.new("123").map(int).flatmap(Option.new)
1051
assert_type(opt, Option[int])
1152
assert opt == Option.Some(123)

0 commit comments

Comments
 (0)