Skip to content

Commit 8c706f6

Browse files
committed
Fix code style issues identified by newer versions of ruff. Also:
* Consistently name `net_connections`'s return values `wire_src_dict` and `wire_dst_dict`. * Remove some unnecessary assignments. * Update `prngs.py` to use `IntEnum`s.
1 parent 137ff35 commit 8c706f6

23 files changed

Lines changed: 267 additions & 251 deletions

pyrtl/analysis.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def _multiplier_stdcell_estimate(width):
222222
def _memory_read_estimate(mem):
223223
# http://www.cs.ucsb.edu/~sherwood/pubs/ICCD-srammodel.pdf
224224
# ROM is assumed to be same delay as SRAM (perhaps optimistic?)
225-
bits, ports, is_rom = _bits_ports_and_isrom_from_memory(mem)
225+
bits, ports, _is_rom = _bits_ports_and_isrom_from_memory(mem)
226226
tech_in_um = 0.130
227227
return 270 * tech_in_um**1.38 * bits**0.25 * ports**1.30 + 1.05
228228

@@ -279,7 +279,7 @@ def critical_path(
279279
the critical paths (which themselves are lists of nets) as the second
280280
"""
281281
critical_paths = [] # storage of all completed critical paths
282-
wire_src_map, dst_map = self.block.net_connections()
282+
wire_src_dict, _wire_dst_dict = self.block.net_connections()
283283

284284
def critical_path_pass(old_critical_path, first_wire):
285285
if isinstance(first_wire, (Input, Const, Register)):
@@ -289,7 +289,7 @@ def critical_path_pass(old_critical_path, first_wire):
289289
if len(critical_paths) >= cp_limit:
290290
raise self._TooManyCPsError()
291291

292-
source = wire_src_map[first_wire]
292+
source = wire_src_dict[first_wire]
293293
critical_path = [source]
294294
critical_path.extend(old_critical_path)
295295
arg_max_time = max(self.timing_map[arg_wire] for arg_wire in source.args)
@@ -500,7 +500,7 @@ def paths(
500500
# present as the destination "net" of Output wires in the dst_nets map. That
501501
# would overly complicate this algorithm: we will assume all values() in the
502502
# dst_nets map are logic nets only. We set this to False for explicitness...
503-
_, dst_nets = block.net_connections(include_virtual_nodes=False)
503+
_wire_src_dict, dst_nets = block.net_connections(include_virtual_nodes=False)
504504
else:
505505
# ... or make sure it's not present otherwise.
506506
for output in block.wirevector_subset(cls=Output):
@@ -594,9 +594,9 @@ def fanout(w: WireVector) -> int:
594594
595595
:return: Integer fanout count.
596596
"""
597-
_, dst_nets = w._block.net_connections()
598-
if w not in dst_nets:
597+
_wire_src_dict, wire_dst_dict = w._block.net_connections()
598+
if w not in wire_dst_dict:
599599
return 0
600600

601-
all_args = [arg for net in dst_nets[w] for arg in net.args]
601+
all_args = [arg for net in wire_dst_dict[w] for arg in net.args]
602602
return len(list(filter(lambda arg: arg is w, all_args)))

pyrtl/core.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -565,8 +565,8 @@ def net_connections(
565565
A :class:`WireVector`'s `source` is the :class:`LogicNet` that sets the
566566
:class:`WireVector`'s value.
567567
568-
A :class:`WireVector`'s `sinks` are the :class:`LogicNets<LogicNet>` that use
569-
the :class:`WireVector`'s value.
568+
A :class:`WireVector`'s `destinations` are the :class:`LogicNets<LogicNet>` that
569+
use the :class:`WireVector`'s value.
570570
571571
This information helps when building a graph representation for the ``Block``.
572572
See :func:`net_graph` for an example.
@@ -577,15 +577,15 @@ def net_connections(
577577
578578
:param include_virtual_nodes: If ``True``, external `sources` (such as an
579579
:class:`Inputs<Input>` and :class:`Consts<Const>`) will be represented as
580-
wires that set themselves, and external `sinks` (such as
580+
wires that set themselves, and external `destinations` (such as
581581
:class:`Outputs<Output>`) will be represented as wires that use themselves.
582582
If ``False``, these nodes will be excluded from the results.
583583
584584
:return: Two dictionaries. The first maps :class:`WireVectors<WireVector>` to
585585
the :class:`LogicNet` that creates their signal (``wire_src_dict``).
586586
The second maps :class:`WireVectors<WireVector>` to a list of
587587
:class:`LogicNets<LogicNet>` that use their signal
588-
(``wire_sink_dict``).
588+
(``wire_dst_dict``).
589589
"""
590590
src_list = {}
591591
dst_list = {}
@@ -645,16 +645,16 @@ def __iter__(self):
645645
"""
646646
from pyrtl.wire import Const, Input, Register
647647

648-
src_dict, dest_dict = self.net_connections()
648+
_wire_src_dict, wire_dst_dict = self.net_connections()
649649
to_clear = self.wirevector_subset((Input, Const, Register))
650650
cleared = set()
651651
remaining = self.logic.copy()
652652
try:
653653
while len(to_clear):
654654
wire_to_check = to_clear.pop()
655655
cleared.add(wire_to_check)
656-
if wire_to_check in dest_dict:
657-
for gate in dest_dict[
656+
if wire_to_check in wire_dst_dict:
657+
for gate in wire_dst_dict[
658658
wire_to_check
659659
]: # loop over logicnets not yet returned
660660
if all(
@@ -804,7 +804,7 @@ def sanity_check_memory_sync(self, wire_src_dict=None):
804804
return # nothing to check here
805805

806806
if wire_src_dict is None:
807-
wire_src_dict, wdd = self.net_connections()
807+
wire_src_dict, _wire_dst_dict = self.net_connections()
808808

809809
from pyrtl.wire import Const, Input
810810

pyrtl/helperfuncs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -880,9 +880,9 @@ def val_to_formatted_str(val: int, format: str, enum_set=None) -> str:
880880
if type == "s":
881881
rval = str(val_to_signed_integer(val, bitwidth))
882882
elif type == "x":
883-
rval = hex(val)[2:] # cuts off '0x' at the start
883+
rval = f"{val:x}"
884884
elif type == "b":
885-
rval = bin(val)[2:] # cuts off '0b' at the start
885+
rval = f"{val:b}"
886886
elif type == "u":
887887
rval = str(int(val)) # nothing fancy
888888
elif type == "e":

pyrtl/passes.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def optimize(
7979
return block
8080

8181

82-
def _get_inverter_chains(wire_creator, wire_users):
82+
def _get_inverter_chains(wire_src_dict, wire_dst_dict):
8383
"""Returns all inverter chains in the block.
8484
8585
The function returns a list of inverter chains in the block. Each inverter chain is
@@ -98,7 +98,7 @@ def _get_inverter_chains(wire_creator, wire_users):
9898
# Build a list of inverter chains. Each inverter chain is a list of WireVectors,
9999
# from source to destination.
100100
inverter_chains = []
101-
for current_dest, current_creator in wire_creator.items():
101+
for current_dest, current_creator in wire_src_dict.items():
102102
if current_creator.op != "~":
103103
# Skip non-inverters.
104104
continue
@@ -107,7 +107,7 @@ def _get_inverter_chains(wire_creator, wire_users):
107107
# a WireVector).
108108
current_arg = current_creator.args[0]
109109
# current_users is the number of LogicNets that use current_dest.
110-
current_users = len(wire_users[current_dest])
110+
current_users = len(wire_dst_dict[current_dest])
111111

112112
# Add the current inverter to the end of this inverter chain.
113113
append_to = None
@@ -117,7 +117,7 @@ def _get_inverter_chains(wire_creator, wire_users):
117117
for inverter_chain in inverter_chains:
118118
chain_arg = inverter_chain[0]
119119
chain_dest = inverter_chain[-1]
120-
chain_users = len(wire_users[chain_dest])
120+
chain_users = len(wire_dst_dict[chain_dest])
121121

122122
if chain_dest is current_arg and chain_users == 1:
123123
# This chain's only destination is the current inverter. Append the
@@ -168,9 +168,9 @@ def _optimize_inverter_chains(block, skip_sanity_check=False):
168168
B -w-> Y
169169
"""
170170

171-
# wire_creator maps from WireVector to the LogicNet that defines its value.
172-
# wire_users maps from WireVector to a list of LogicNets that use its value.
173-
wire_creator, wire_users = block.net_connections()
171+
# wire_src_dict maps from WireVector to the LogicNet that defines its value.
172+
# wire_dst_dict maps from WireVector to a list of LogicNets that use its value.
173+
wire_src_dict, wire_dst_dict = block.net_connections()
174174

175175
new_logic = set()
176176
net_removal_set = set()
@@ -199,9 +199,9 @@ def _optimize_inverter_chains(block, skip_sanity_check=False):
199199
# will be mapped to A and E will be mapped to C. Hence, when finding the replacement
200200
# of E, we have to first query the dict to get C, and then query the dict again on C
201201
# to get A.
202-
wire_src_dict = _ProducerList()
202+
wire_producer = _ProducerList()
203203

204-
for inverter_chain in _get_inverter_chains(wire_creator, wire_users):
204+
for inverter_chain in _get_inverter_chains(wire_src_dict, wire_dst_dict):
205205
# If len(inverter_chain) = n, there are n-1 inverters in the chain. We only
206206
# remove inverters if there are at least two inverters in a chain.
207207
if len(inverter_chain) > 2:
@@ -215,10 +215,10 @@ def _optimize_inverter_chains(block, skip_sanity_check=False):
215215
wires_to_remove = inverter_chain[start_idx:]
216216
wire_removal_set.update(wires_to_remove)
217217
# Remove inverters used in the chain.
218-
inverters_to_remove = {wire_creator[wire] for wire in wires_to_remove}
218+
inverters_to_remove = {wire_src_dict[wire] for wire in wires_to_remove}
219219
net_removal_set.update(inverters_to_remove)
220220
# Map the end wire of the inverter chain to the beginning wire.
221-
wire_src_dict[inverter_chain[-1]] = inverter_chain[start_idx - 1]
221+
wire_producer[inverter_chain[-1]] = inverter_chain[start_idx - 1]
222222

223223
# This loop recreates the block with inverter chains removed. It adds each LogicNet
224224
# in the original block to the new block if it is not marked for removal, and
@@ -230,7 +230,7 @@ def _optimize_inverter_chains(block, skip_sanity_check=False):
230230
LogicNet(
231231
net.op,
232232
net.op_param,
233-
args=tuple(wire_src_dict.find_producer(x) for x in net.args),
233+
args=tuple(wire_producer.find_producer(x) for x in net.args),
234234
dests=net.dests,
235235
)
236236
)
@@ -1042,7 +1042,7 @@ def direct_connect_outputs(block=None):
10421042
# NOTE: would use transform.all_nets(), but it becomes tricky when we want to remove
10431043
# more than just the current net on a single pass
10441044
block = working_block(block)
1045-
_, dst_nets = block.net_connections()
1045+
_wire_src_dict, wire_dst_dict = block.net_connections()
10461046

10471047
nets_to_remove = set()
10481048
nets_to_add = set()
@@ -1053,10 +1053,10 @@ def direct_connect_outputs(block=None):
10531053
continue
10541054

10551055
dest_wire = net.dests[0]
1056-
if dest_wire not in dst_nets or len(dst_nets[dest_wire]) > 1:
1056+
if dest_wire not in wire_dst_dict or len(wire_dst_dict[dest_wire]) > 1:
10571057
continue
10581058

1059-
dst_net = dst_nets[dest_wire][0]
1059+
dst_net = wire_dst_dict[dest_wire][0]
10601060
if dst_net.op != "w" or not isinstance(dst_net.dests[0], Output):
10611061
continue
10621062

@@ -1105,7 +1105,7 @@ def two_way_fanout(block=None):
11051105

11061106
block = working_block(block)
11071107

1108-
_, dst_map = block.net_connections()
1108+
_wire_src_dict, wire_dst_dict = block.net_connections()
11091109
# Two-pass approach: Remember which nets will need to change, in case there are
11101110
# multiple arguments which will be changing along the way.
11111111
nets_to_update = collections.defaultdict(list)
@@ -1114,7 +1114,7 @@ def two_way_fanout(block=None):
11141114
if curr_fanout > 1:
11151115
s = _make_tree(wire, block, curr_fanout)
11161116
curr_ix = 0
1117-
for dst_net in dst_map[wire]:
1117+
for dst_net in wire_dst_dict[wire]:
11181118
for i, arg in enumerate(dst_net.args):
11191119
if arg is wire:
11201120
nets_to_update[dst_net].append((wire, i, s[curr_ix]))

pyrtl/rtllib/matrix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1570,7 +1570,7 @@ def matrix_wv_to_list(
15701570
15711571
:return: A Python list of lists.
15721572
"""
1573-
value = bin(matrix_wv)[2:].zfill(rows * columns * bits)
1573+
value = f"{matrix_wv:b}".zfill(rows * columns * bits)
15741574

15751575
result = [[0 for _ in range(columns)] for _ in range(rows)]
15761576

pyrtl/rtllib/muxes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,9 @@ def _check_finalized(self):
235235

236236
def option(self, select_val, *data_signals):
237237
self._check_finalized()
238-
instr, ib = pyrtl.infer_val_and_bitwidth(select_val, self.signal_wire.bitwidth)
238+
instr = pyrtl.infer_val_and_bitwidth(
239+
select_val, self.signal_wire.bitwidth
240+
).value
239241
if instr in self.instructions:
240242
msg = f"instruction {select_val} already exists"
241243
raise pyrtl.PyrtlError(msg)

pyrtl/rtllib/prngs.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import enum
12
import random
23
from math import ceil, log2
34

@@ -145,26 +146,30 @@ def prng_xoroshiro128(
145146
counter = pyrtl.Register(counter_bitwidth, "counter")
146147
gen_done = counter == gen_cycles - 1
147148
state = pyrtl.Register(1)
148-
WAIT, GEN = (pyrtl.Const(x) for x in range(2))
149+
150+
class State(enum.IntEnum):
151+
WAIT = 0
152+
GEN = 1
153+
149154
with pyrtl.conditional_assignment:
150155
with load:
151156
s0.next |= seed[:64]
152157
s1.next |= seed[64:]
153-
state.next |= WAIT
158+
state.next |= State.WAIT
154159
with req:
155160
counter.next |= 0
156161
s0.next |= s0_next
157162
s1.next |= s1_next
158163
rand.next |= pyrtl.concat(rand, output)
159-
state.next |= GEN
160-
with state == GEN:
164+
state.next |= State.GEN
165+
with state == State.GEN:
161166
with ~gen_done:
162167
counter.next |= counter + 1
163168
s0.next |= s0_next
164169
s1.next |= s1_next
165170
rand.next |= pyrtl.concat(rand, output)
166171

167-
ready = ~load & ~req & (state == GEN) & gen_done
172+
ready = ~load & ~req & (state == State.GEN) & gen_done
168173
return ready, rand[-bitwidth:] # return MSBs because LSBs are less random
169174

170175

@@ -263,34 +268,43 @@ def csprng_trivium(
263268
init_done = counter == init_cycles
264269
gen_done = counter == gen_cycles - 1
265270
state = pyrtl.Register(2)
266-
WAIT, INIT, GEN = (pyrtl.Const(x) for x in range(3))
271+
272+
class State(enum.IntEnum):
273+
WAIT = 0
274+
INIT = 1
275+
GEN = 2
276+
267277
with pyrtl.conditional_assignment:
268278
with load:
269279
counter.next |= 0
270280
a.next |= key
271281
b.next |= iv
272282
c.next |= pyrtl.concat(pyrtl.Const("3'b111"), pyrtl.Const(0, 108))
273-
state.next |= INIT
283+
state.next |= State.INIT
274284
with req:
275285
counter.next |= 0
276286
a.next |= a_next
277287
b.next |= b_next
278288
c.next |= c_next
279289
rand.next |= pyrtl.concat(rand, *output)
280-
state.next |= GEN
281-
with state == INIT:
290+
state.next |= State.GEN
291+
with state == State.INIT:
282292
with ~init_done:
283293
counter.next |= counter + 1
284294
a.next |= a_next
285295
b.next |= b_next
286296
c.next |= c_next
287-
with state == GEN:
297+
with state == State.GEN:
288298
with ~gen_done:
289299
counter.next |= counter + 1
290300
a.next |= a_next
291301
b.next |= b_next
292302
c.next |= c_next
293303
rand.next |= pyrtl.concat(rand, *output)
294304

295-
ready = ~load & ~req & ((state == INIT) & init_done | (state == GEN) & gen_done)
305+
ready = (
306+
~load
307+
& ~req
308+
& ((state == State.INIT) & init_done | (state == State.GEN) & gen_done)
309+
)
296310
return ready, rand

pyrtl/simulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1594,7 +1594,7 @@ def __len__(self):
15941594
raise PyrtlError(msg)
15951595
# return the length of the list of some element in the dictionary (all should be
15961596
# the same)
1597-
wire, value_list = next(x for x in self.trace.items())
1597+
_wire, value_list = next(x for x in self.trace.items())
15981598
return len(value_list)
15991599

16001600
def add_step(self, value_map):

pyrtl/transform.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,12 @@ def wire_transform(
7373
:param block: The Block to replace wires on. Defaults to the :ref:`working_block`.
7474
"""
7575
block = working_block(block)
76-
src_nets, dst_nets = block.net_connections(include_virtual_nodes=False)
76+
wire_src_dict, wire_dst_dict = block.net_connections(include_virtual_nodes=False)
7777
for orig_wire in block.wirevector_subset(select_types, exclude_types):
7878
new_src, new_dst = transform_func(orig_wire)
79-
replace_wire_fast(orig_wire, new_src, new_dst, src_nets, dst_nets, block)
79+
replace_wire_fast(
80+
orig_wire, new_src, new_dst, wire_src_dict, wire_dst_dict, block
81+
)
8082

8183

8284
def all_wires(transform_func):
@@ -96,9 +98,9 @@ def replace_wires(wire_map, block=None):
9698
:param block: block to operate over (defaults to :ref:`working_block`)
9799
"""
98100
block = working_block(block)
99-
src_nets, dst_nets = block.net_connections(include_virtual_nodes=False)
101+
wire_src_dict, wire_dst_dict = block.net_connections(include_virtual_nodes=False)
100102
for old_w, new_w in wire_map.items():
101-
replace_wire_fast(old_w, new_w, new_w, src_nets, dst_nets, block)
103+
replace_wire_fast(old_w, new_w, new_w, wire_src_dict, wire_dst_dict, block)
102104

103105

104106
def replace_wire_fast(orig_wire, new_src, new_dst, src_nets, dst_nets, block=None):

tests/rtllib/test_libutils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test_successful_partition(self):
1616
def test_failing_partition(self):
1717
w = pyrtl.WireVector(14)
1818
with self.assertRaises(pyrtl.PyrtlError):
19-
_ = libutils.partition_wire(w, 4)
19+
libutils.partition_wire(w, 4)
2020

2121
def test_partition_sim(self):
2222
pyrtl.reset_working_block()

0 commit comments

Comments
 (0)