Skip to content
This repository was archived by the owner on Apr 22, 2024. It is now read-only.

Commit 3f1a0c9

Browse files
committed
IPv4 class and tests
Add structure to pack/unpack IPv4 headers, plus tests. The unpack method returns primitives instead of 'UBIntX' or 'BinaryData'. Some variables are mixed together as class variables. Those are protected so the lib user doesn't see them.
1 parent 934dbc7 commit 3f1a0c9

2 files changed

Lines changed: 134 additions & 3 deletions

File tree

pyof/foundation/network_types.py

Lines changed: 124 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
"""
55

66
from pyof.foundation.base import GenericStruct
7-
from pyof.foundation.basic_types import BinaryData, HWAddress, UBInt8, UBInt16
7+
from pyof.foundation.basic_types import (BinaryData, HWAddress, IPAddress,
8+
UBInt8, UBInt16)
89
from pyof.foundation.exceptions import PackException
910

10-
__all__ = ('Ethernet', 'GenericTLV', 'TLVWithSubType', 'LLDP')
11+
__all__ = ('Ethernet', 'GenericTLV', 'IPv4', 'TLVWithSubType', 'LLDP')
1112

1213

1314
class Ethernet(GenericStruct):
@@ -138,6 +139,127 @@ def get_size(self, value=None):
138139
return 2 + self.length
139140

140141

142+
class IPv4(GenericStruct):
143+
"""IPv4 packet "struct"
144+
145+
Contains all fields of an IP version 4 packet header, plus the upper layer
146+
content as binary data.
147+
Some of the fields were merged together because of their size being
148+
inferior to 8 bits.
149+
"""
150+
151+
# IP protocol version + Internet Header Length (words)
152+
_version_ihl = UBInt8()
153+
# Differentiated Services Code Point (ToS - Type of Service) +
154+
# Explicit Congestion Notification
155+
_dscp_ecn = UBInt8()
156+
# IP packet length (bytes)
157+
length = UBInt16()
158+
# Packet ID - common to all fragments
159+
identification = UBInt16()
160+
# Fragmentation flags + fragmentation offset
161+
_flags_offset = UBInt16()
162+
# Packet time-to-live
163+
ttl = UBInt8()
164+
# Upper layer protocol number
165+
protocol = UBInt8()
166+
# Header checksum
167+
checksum = UBInt16()
168+
# Source address
169+
source = IPAddress()
170+
# Destination address
171+
destination = IPAddress()
172+
# IP Options - up to 320 bits, always padded to 32 bits
173+
options = BinaryData()
174+
# Packet data
175+
data = BinaryData()
176+
177+
def __init__(self, version=4, ihl=5, dscp=0, ecn=0, length=0, # noqa
178+
identification=0, flags=0, offset=0, ttl=255, protocol=0,
179+
checksum=0, source="0.0.0.0", destination="0.0.0.0",
180+
options=b'', data=b''):
181+
"""Create the Packet and set instance attributes."""
182+
super().__init__()
183+
self.version = version
184+
self.ihl = ihl
185+
self.dscp = dscp
186+
self.ecn = ecn
187+
self.length = length
188+
self.identification = identification
189+
self.flags = flags
190+
self.offset = offset
191+
self.ttl = ttl
192+
self.protocol = protocol
193+
self.checksum = checksum
194+
self.source = source
195+
self.destination = destination
196+
self.options = options
197+
self.data = data
198+
199+
def _update_checksum(self):
200+
"""Updates the packet checksum to enable integrity check."""
201+
source_list = [int(octet) for octet in self.source.split(".")]
202+
destination_list = [int(octet) for octet in
203+
self.destination.split(".")]
204+
source_upper = (source_list[0] << 8) + source_list[1]
205+
source_lower = (source_list[2] << 8) + source_list[3]
206+
destination_upper = (destination_list[0] << 8) + destination_list[1]
207+
destination_lower = (destination_list[2] << 8) + destination_list[3]
208+
209+
block_sum = ((self._version_ihl << 8 | self._dscp_ecn) + self.length +
210+
self.identification + self._flags_offset +
211+
(self.ttl << 8 | self.protocol) + source_upper +
212+
source_lower + destination_upper + destination_lower)
213+
214+
while block_sum > 65535:
215+
carry = block_sum >> 16
216+
block_sum = (block_sum & 65535) + carry
217+
218+
self.checksum = ~block_sum & 65535
219+
220+
def pack(self, value=None):
221+
# Set the correct IHL based on options size
222+
if self.options:
223+
self.ihl += int(len(self.options) / 4)
224+
225+
# Set the correct packet length based on header length and data
226+
self.length = int(self.ihl * 4 + len(self.data))
227+
228+
self._version_ihl = self.version << 4 | self.ihl
229+
self._dscp_ecn = self.dscp << 2 | self.ecn
230+
self._flags_offset = self.flags << 13 | self.offset
231+
232+
# Set the checksum field before packing
233+
self._update_checksum()
234+
235+
return super().pack()
236+
237+
def unpack(self, buff, offset=0):
238+
super().unpack(buff, offset)
239+
240+
self.version = self._version_ihl.value >> 4
241+
self.ihl = self._version_ihl.value & 15
242+
self.dscp = self._dscp_ecn.value >> 2
243+
self.ecn = self._dscp_ecn.value & 3
244+
self.length = self.length.value
245+
self.identification = self.identification.value
246+
self.flags = self._flags_offset.value >> 13
247+
self.offset = self._flags_offset.value & 8191
248+
self.ttl = self.ttl.value
249+
self.protocol = self.protocol.value
250+
self.checksum = self.checksum.value
251+
self.source = self.source.value
252+
self.destination = self.destination.value
253+
254+
if self.ihl > 5:
255+
options_size = (self.ihl - 5) * 4
256+
self.data = self.options.value[options_size:]
257+
self.options = self.options.value[:options_size]
258+
else:
259+
self.data = self.options.value
260+
self.options = b''
261+
262+
141263
class TLVWithSubType(GenericTLV):
142264
"""Modify the :class:`GenericTLV` to a Organization Specific TLV structure.
143265

tests/test_foundation/test_network_types.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import unittest
33

44
from pyof.foundation.basic_types import BinaryData
5-
from pyof.foundation.network_types import GenericTLV
5+
from pyof.foundation.network_types import GenericTLV, IPv4
66

77

88
class TestNetworkTypes(unittest.TestCase):
@@ -15,3 +15,12 @@ def test_GenTLV_value_unpack(self):
1515
tlv_unpacked = GenericTLV()
1616
tlv_unpacked.unpack(tlv.pack())
1717
self.assertEqual(tlv.value.value, tlv_unpacked.value.value)
18+
19+
def test_IPv4_pack_unpack(self):
20+
"""Test pack/unpack of IPv4 class."""
21+
packet = IPv4(ttl=64, protocol=17, source="192.168.0.1",
22+
destination="172.16.200.132", data=b'testdata')
23+
packed = packet.pack()
24+
unpacked = IPv4()
25+
unpacked.unpack(packed)
26+
self.assertEqual(packed, unpacked.pack())

0 commit comments

Comments
 (0)