Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions kmir/src/kmir/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from .build import HASKELL_DEF_DIR, LLVM_DEF_DIR, LLVM_LIB_DIR
from .cargo import CargoProject
from .kmir import KMIR, DecodeMode, KMIRAPRNodePrinter
from .kmir import KMIR, KMIRAPRNodePrinter
from .linker import link
from .options import (
GenSpecOpts,
Expand Down Expand Up @@ -352,13 +352,6 @@ def _arg_parser() -> ArgumentParser:
prove_rs_parser.add_argument(
'--start-symbol', type=str, metavar='SYMBOL', default='main', help='Symbol name to begin execution from'
)
prove_rs_parser.add_argument(
'--decode-mode',
type=DecodeMode,
metavar='DECODE_MODE',
default=DecodeMode.NONE,
help='Allocation decoding mode: NONE (default), PARTIAL, or FULL',
)

link_parser = command_parser.add_parser(
'link', help='Link together 2 or more SMIR JSON files', parents=[kcli_args.logging_args]
Expand Down Expand Up @@ -435,7 +428,6 @@ def _parse_args(ns: Namespace) -> KMirOpts:
save_smir=ns.save_smir,
smir=ns.smir,
start_symbol=ns.start_symbol,
decode_mode=ns.decode_mode,
)
case 'link':
return LinkOpts(
Expand Down
14 changes: 7 additions & 7 deletions kmir/src/kmir/alloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,21 @@ def from_dict(dct: dict[str, Any]) -> Allocation:

@dataclass
class ProvenanceMap:
ptrs: list[ProvenanceItem]
ptrs: list[ProvenanceEntry]

@staticmethod
def from_dict(dct: dict[str, Any]) -> ProvenanceMap:
return ProvenanceMap(
ptrs=[
ProvenanceItem(
size=int(size),
prov=AllocId(prov),
ProvenanceEntry(
offset=int(size),
alloc_id=AllocId(prov),
)
for size, prov in dct['ptrs']
],
)


class ProvenanceItem(NamedTuple):
size: int
prov: AllocId
class ProvenanceEntry(NamedTuple):
offset: int
alloc_id: AllocId
125 changes: 86 additions & 39 deletions kmir/src/kmir/decoding.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,46 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING

from pyk.kast.inner import KApply
from pyk.kast.prelude.bytes import bytesToken
from pyk.kast.prelude.kint import intToken

from .alloc import Allocation, AllocInfo, Memory, ProvenanceMap
from .ty import ArrayT, Bool, EnumT, Int, IntTy, Uint
from .value import AggregateValue, BoolValue, IntValue, RangeValue, Value
from pyk.kast.prelude.string import stringToken

from .alloc import Allocation, AllocInfo, Memory, ProvenanceEntry, ProvenanceMap
from .ty import ArrayT, Bool, EnumT, Int, IntTy, PtrT, RefT, Str, Uint
from .value import (
NO_METADATA,
AggregateValue,
AllocRefValue,
BoolValue,
DynamicSize,
IntValue,
RangeValue,
StaticSize,
StrValue,
Value,
)

if TYPE_CHECKING:
from collections.abc import Mapping

from pyk.kast import KInner

from .alloc import AllocId
from .ty import Ty, TypeMetadata, UintTy
from .value import Metadata


@dataclass
class UnableToDecodeValue(Value):
data: bytes
type_info: TypeMetadata
msg: str

def to_kast(self) -> KInner:
return KApply(
'Evaluation::UnableToDecodeValue',
bytesToken(self.data),
KApply('TypeInfo::VoidType'), # TODO: TypeInfo -> KAST transformation
'Evaluation::UnableToDecodePy',
stringToken(self.msg),
)


@dataclass
class UnableToDecodeAlloc(Value):
data: bytes
ty: Ty

def to_kast(self) -> KInner:
return KApply(
'Evaluation::UnableToDecodeAlloc',
bytesToken(self.data),
KApply('ty', intToken(self.ty)),
KApply('ProvenanceMapEntries::empty'), # TODO
)


class ProvenanceMapEntry(NamedTuple):
offset: int
alloc_id: AllocId


def decode_alloc_or_unable(alloc_info: AllocInfo, types: Mapping[Ty, TypeMetadata]) -> Value:
match alloc_info:
case AllocInfo(
Expand All @@ -66,27 +55,81 @@ def decode_alloc_or_unable(alloc_info: AllocInfo, types: Mapping[Ty, TypeMetadat
),
):
data = bytes(n or 0 for n in bytez)
return _decode_memory_alloc_or_unable(data=data, ptrs=ptrs, ty=ty, types=types)
case _:
raise AssertionError('Unhandled case')


if not ptrs: # TODO generalize to lists with at most one entry
type_info = types[ty]
return decode_value_or_unable(data=data, type_info=type_info, types=types)
def _decode_memory_alloc_or_unable(
data: bytes,
ptrs: list[ProvenanceEntry],
ty: Ty,
types: Mapping[Ty, TypeMetadata],
) -> Value:
try:
type_info = types[ty]
except KeyError:
return UnableToDecodeValue(f'Unknown type: {ty}')

return UnableToDecodeAlloc(data=data, ty=ty)
match ptrs:
case []:
return decode_value_or_unable(data=data, type_info=type_info, types=types)

case [ProvenanceEntry(0, alloc_id)]:
if (pointee_ty := _pointee_ty(type_info)) is not None: # ensures this is a reference type
try:
pointee_type_info = types[pointee_ty]
except KeyError:
return UnableToDecodeValue(f'Unknown pointee type: {pointee_ty}')

metadata = _metadata(pointee_type_info)

if len(data) == 8:
# single slim pointer (assumes usize == u64)
return AllocRefValue(alloc_id=alloc_id, metadata=metadata)

if len(data) == 16 and metadata == DynamicSize(1):
# sufficient data to decode dynamic size (assumes usize == u64)
# expect fat pointer
return AllocRefValue(
alloc_id=alloc_id,
metadata=DynamicSize(int.from_bytes(data[8:16], byteorder='little', signed=False)),
)

return UnableToDecodeValue(f'Unable to decode alloc: {data!r}, of type: {type_info}')


def _pointee_ty(type_info: TypeMetadata) -> Ty | None:
match type_info:
case PtrT(ty) | RefT(ty):
return ty
case _:
raise AssertionError('Unhandled case')
return None


def _metadata(type_info: TypeMetadata) -> Metadata:
match type_info:
case ArrayT(length=None):
return DynamicSize(1) # 1 is a placeholder, the actual size is inferred from the slice data
case ArrayT(length=int() as length):
return StaticSize(length)
case _:
return NO_METADATA


def decode_value_or_unable(data: bytes, type_info: TypeMetadata, types: Mapping[Ty, TypeMetadata]) -> Value:
try:
return decode_value(data=data, type_info=type_info, types=types)
except ValueError:
return UnableToDecodeValue(data=data, type_info=type_info)
except ValueError as err:
return UnableToDecodeValue(f'Unable to decode value: {data!r}, of type: {type_info}: {err}')


def decode_value(data: bytes, type_info: TypeMetadata, types: Mapping[Ty, TypeMetadata]) -> Value:
match type_info:
case Bool():
return _decode_bool(data)
case Str():
return _decode_str(data)
case Uint(int_ty) | Int(int_ty):
return _decode_int(data, int_ty)
case ArrayT(elem_ty, length):
Expand All @@ -107,6 +150,10 @@ def _decode_bool(data: bytes) -> Value:
raise ValueError(f'Cannot decode as Bool: {data!r}')


def _decode_str(data: bytes) -> Value:
return StrValue(data.decode('utf-8'))


def _decode_int(data: bytes, int_ty: IntTy | UintTy) -> Value:
nbytes = int_ty.value
if len(data) != nbytes:
Expand Down
2 changes: 1 addition & 1 deletion kmir/src/kmir/kdist/mir-semantics/kmir.md
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ will be `129`.

rule #switchMatch(0, BoolVal(B) ) => notBool B
rule #switchMatch(1, BoolVal(B) ) => B
rule #switchMatch(I, Integer(I2, WIDTH, _)) => I ==Int bitRangeInt(I2, 0, WIDTH)
rule #switchMatch(I, Integer(I2, WIDTH, _)) => I ==Int truncate(I2, WIDTH, Unsigned)
```

`Return` simply returns from a function call, using the information
Expand Down
26 changes: 11 additions & 15 deletions kmir/src/kmir/kdist/mir-semantics/lemmas/kmir-lemmas.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,16 @@ Therefore, its value range should be simplified for symbolic input asserted to b

```k
rule truncate(VAL, WIDTH, Unsigned) => VAL
requires VAL <Int (1 <<Int WIDTH)
requires 0 <Int WIDTH
andBool VAL <Int (1 <<Int WIDTH)
andBool 0 <=Int VAL
[simplification]
[simplification, preserves-definedness] // , smt-lemma], but `Unsigned` needs to be smtlib

rule truncate(VAL, WIDTH, Signed) => VAL
requires VAL <Int (1 <<Int (WIDTH -Int 1))
requires 0 <Int WIDTH
andBool VAL <Int (1 <<Int (WIDTH -Int 1))
andBool 0 -Int (1 <<Int (WIDTH -Int 1)) <=Int VAL
[simplification]
[simplification, preserves-definedness] // , smt-lemma], but `Signed` needs to be smtlib
```

However, `truncate` gets evaluated and is therefore not present any more for this simplification.
Expand All @@ -97,17 +99,6 @@ The following simplification rules operate on the expression created by evaluati
power of two but the semantics will always operate with these particular ones.

```k
rule VAL &Int MASK => VAL
requires 0 <=Int VAL
andBool VAL <=Int MASK
andBool ( MASK ==Int bitmask8
orBool MASK ==Int bitmask16
orBool MASK ==Int bitmask32
orBool MASK ==Int bitmask64
orBool MASK ==Int bitmask128
)
[simplification, preserves-definedness]

syntax Int ::= "bitmask8" [macro]
| "bitmask16" [macro]
| "bitmask32" [macro]
Expand All @@ -120,6 +111,11 @@ power of two but the semantics will always operate with these particular ones.
rule bitmask64 => ( 1 <<Int 64 ) -Int 1
rule bitmask128 => ( 1 <<Int 128) -Int 1

rule VAL &Int bitmask8 => VAL requires 0 <=Int VAL andBool VAL <=Int bitmask8 [simplification, preserves-definedness, smt-lemma]
rule VAL &Int bitmask16 => VAL requires 0 <=Int VAL andBool VAL <=Int bitmask16 [simplification, preserves-definedness, smt-lemma]
rule VAL &Int bitmask32 => VAL requires 0 <=Int VAL andBool VAL <=Int bitmask32 [simplification, preserves-definedness, smt-lemma]
rule VAL &Int bitmask64 => VAL requires 0 <=Int VAL andBool VAL <=Int bitmask64 [simplification, preserves-definedness, smt-lemma]
rule VAL &Int bitmask128 => VAL requires 0 <=Int VAL andBool VAL <=Int bitmask128 [simplification, preserves-definedness, smt-lemma]
```


Expand Down
Loading