mirror of
https://github.com/espressif/openthread.git
synced 2026-06-06 05:24:51 +00:00
[tests] support testing commissioning (#3441)
Add support for testing commissioning process, including: 1. Add DTLS, Thread Discovery and commissioning messages(JOIN_FIN.req etc) parsing; 2. Support parsing log from UART response, and construct decrypted messages; 3. Common commands for commissioning related verification;
This commit is contained in:
+10
-10
@@ -416,20 +416,20 @@ build_samr21() {
|
||||
export ASAN_OPTIONS=symbolize=1 || die
|
||||
export DISTCHECK_CONFIGURE_FLAGS= CPPFLAGS=-DOPENTHREAD_POSIX_VIRTUAL_TIME=1 || die
|
||||
./bootstrap || die
|
||||
make -f examples/Makefile-posix distcheck || die
|
||||
CERT_LOG=1 make -f examples/Makefile-posix distcheck || die
|
||||
}
|
||||
|
||||
[ $BUILD_TARGET != posix-32-bit ] || {
|
||||
./bootstrap || die
|
||||
COVERAGE=1 CFLAGS=-m32 CXXFLAGS=-m32 LDFLAGS=-m32 make -f examples/Makefile-posix check || die
|
||||
CERT_LOG=1 COVERAGE=1 CFLAGS=-m32 CXXFLAGS=-m32 LDFLAGS=-m32 make -f examples/Makefile-posix check || die
|
||||
}
|
||||
|
||||
[ $BUILD_TARGET != posix-app-cli ] || {
|
||||
./bootstrap || die
|
||||
# enable code coverage for OpenThread transceiver only
|
||||
COVERAGE=1 VIRTUAL_TIME_UART=1 make -f examples/Makefile-posix || die
|
||||
COVERAGE=1 make -f src/posix/Makefile-posix || die
|
||||
COVERAGE=1 PYTHONUNBUFFERED=1 OT_CLI_PATH="$(pwd)/$(ls output/posix/*/bin/ot-cli)" RADIO_DEVICE="$(pwd)/$(ls output/*/bin/ot-ncp-radio)" make -f src/posix/Makefile-posix check || die
|
||||
CERT_LOG=1 COVERAGE=1 VIRTUAL_TIME_UART=1 make -f examples/Makefile-posix || die
|
||||
CERT_LOG=1 COVERAGE=1 make -f src/posix/Makefile-posix || die
|
||||
CERT_LOG=1 COVERAGE=1 PYTHONUNBUFFERED=1 OT_CLI_PATH="$(pwd)/$(ls output/posix/*/bin/ot-cli)" RADIO_DEVICE="$(pwd)/$(ls output/*/bin/ot-ncp-radio)" make -f src/posix/Makefile-posix check || die
|
||||
}
|
||||
|
||||
[ $BUILD_TARGET != posix-app-pty ] || {
|
||||
@@ -439,7 +439,7 @@ build_samr21() {
|
||||
|
||||
[ $BUILD_TARGET != posix-mtd ] || {
|
||||
./bootstrap || die
|
||||
COVERAGE=1 CFLAGS=-m32 CXXFLAGS=-m32 LDFLAGS=-m32 USE_MTD=1 make -f examples/Makefile-posix check || die
|
||||
CERT_LOG=1 COVERAGE=1 CFLAGS=-m32 CXXFLAGS=-m32 LDFLAGS=-m32 USE_MTD=1 make -f examples/Makefile-posix check || die
|
||||
}
|
||||
|
||||
[ $BUILD_TARGET != posix-ncp-spi ] || {
|
||||
@@ -449,15 +449,15 @@ build_samr21() {
|
||||
|
||||
[ $BUILD_TARGET != posix-app-ncp ] || {
|
||||
./bootstrap || die
|
||||
COVERAGE=1 VIRTUAL_TIME_UART=1 make -f examples/Makefile-posix || die
|
||||
CERT_LOG=1 COVERAGE=1 VIRTUAL_TIME_UART=1 make -f examples/Makefile-posix || die
|
||||
# enable code coverage for OpenThread posix radio
|
||||
COVERAGE=1 make -f src/posix/Makefile-posix || die
|
||||
COVERAGE=1 PYTHONUNBUFFERED=1 OT_NCP_PATH="$(pwd)/$(ls output/posix/*/bin/ot-ncp)" RADIO_DEVICE="$(pwd)/$(ls output/*/bin/ot-ncp-radio)" NODE_TYPE=ncp-sim make -f src/posix/Makefile-posix check || die
|
||||
CERT_LOG=1 COVERAGE=1 make -f src/posix/Makefile-posix || die
|
||||
CERT_LOG=1 COVERAGE=1 PYTHONUNBUFFERED=1 OT_NCP_PATH="$(pwd)/$(ls output/posix/*/bin/ot-ncp)" RADIO_DEVICE="$(pwd)/$(ls output/*/bin/ot-ncp-radio)" NODE_TYPE=ncp-sim make -f src/posix/Makefile-posix check || die
|
||||
}
|
||||
|
||||
[ $BUILD_TARGET != posix-ncp ] || {
|
||||
./bootstrap || die
|
||||
COVERAGE=1 PYTHONUNBUFFERED=1 NODE_TYPE=ncp-sim make -f examples/Makefile-posix check || die
|
||||
CERT_LOG=1 COVERAGE=1 PYTHONUNBUFFERED=1 NODE_TYPE=ncp-sim make -f examples/Makefile-posix check || die
|
||||
}
|
||||
|
||||
[ $BUILD_TARGET != toranj-test-framework ] || {
|
||||
|
||||
@@ -35,6 +35,7 @@ AM_DISTCHECK_CONFIGURE_FLAGS = \
|
||||
--enable-application-coap \
|
||||
--enable-application-coap-secure \
|
||||
--enable-border-router \
|
||||
--enable-cert-log \
|
||||
--enable-cli \
|
||||
--enable-commissioner \
|
||||
--enable-dhcp6-client \
|
||||
|
||||
@@ -125,9 +125,11 @@ EXTRA_DIST = \
|
||||
command.py \
|
||||
common.py \
|
||||
config.py \
|
||||
dtls.py \
|
||||
ipv6.py \
|
||||
lowpan.py \
|
||||
mac802154.py \
|
||||
mesh_cop.py \
|
||||
message.py \
|
||||
mle.py \
|
||||
net_crypto.py \
|
||||
|
||||
@@ -33,11 +33,14 @@ import sys
|
||||
import ipv6
|
||||
import network_data
|
||||
import network_layer
|
||||
import common
|
||||
import config
|
||||
import mesh_cop
|
||||
import mle
|
||||
|
||||
from collections import Counter
|
||||
from enum import IntEnum
|
||||
from network_data import Prefix, BorderRouter, LowpanId
|
||||
|
||||
class CheckType(IntEnum):
|
||||
CONTAIN = 0
|
||||
@@ -270,6 +273,7 @@ def check_parent_request(command_msg, is_first_request):
|
||||
elif not scan_mask.end_device:
|
||||
raise ValueError("Second parent request without E bit set")
|
||||
|
||||
|
||||
def check_parent_response(command_msg, mle_frame_counter = CheckType.OPTIONAL):
|
||||
"""Verify a properly formatted Parent Response command message.
|
||||
"""
|
||||
@@ -478,3 +482,72 @@ def check_address_registration_tlv(addr_reg_tlv, address_set):
|
||||
"""Verify all addresses contained in address_set are contained in add_reg_tlv
|
||||
"""
|
||||
assert all(addr in addr_reg_tlv.addresses for addr in address_set), 'Some addresses are not included in AddressRegistration TLV'
|
||||
|
||||
def assert_contains_tlv(tlvs, check_type, tlv_type):
|
||||
"""Assert a tlv list contains specific tlv and return the first qualified.
|
||||
"""
|
||||
tlvs = [tlv for tlv in tlvs if isinstance(tlv, tlv_type)]
|
||||
if check_type is CheckType.CONTAIN:
|
||||
assert tlvs
|
||||
return tlvs[0]
|
||||
elif check_type is CheckType.NOT_CONTAIN:
|
||||
assert not tlvs
|
||||
return None
|
||||
elif check_type is CheckType.OPTIONAL:
|
||||
return None
|
||||
else:
|
||||
raise ValueError("Invalid check type: {}".format(check_type))
|
||||
|
||||
def check_discovery_request(command_msg):
|
||||
"""Verify a properly formatted Thread Discovery Request command message.
|
||||
"""
|
||||
assert not isinstance(command_msg.mle, mle.MleMessageSecured)
|
||||
tlvs = command_msg.assertMleMessageContainsTlv(mle.ThreadDiscovery).tlvs
|
||||
request = assert_contains_tlv(tlvs, CheckType.CONTAIN, mesh_cop.DiscoveryRequest)
|
||||
assert request.version == config.PROTOCOL_VERSION
|
||||
|
||||
def check_discovery_response(command_msg, request_src_addr, steering_data=CheckType.OPTIONAL):
|
||||
"""Verify a properly formatted Thread Discovery Response command message.
|
||||
"""
|
||||
assert not isinstance(command_msg.mle, mle.MleMessageSecured)
|
||||
assert command_msg.mac_header.src_address.type == common.MacAddressType.LONG
|
||||
assert command_msg.mac_header.dest_address == request_src_addr
|
||||
|
||||
tlvs = command_msg.assertMleMessageContainsTlv(mle.ThreadDiscovery).tlvs
|
||||
response = assert_contains_tlv(tlvs, CheckType.CONTAIN, mesh_cop.DiscoveryResponse)
|
||||
assert response.version == config.PROTOCOL_VERSION
|
||||
assert_contains_tlv(tlvs, CheckType.CONTAIN, mesh_cop.ExtendedPanid)
|
||||
assert_contains_tlv(tlvs, CheckType.CONTAIN, mesh_cop.NetworkName)
|
||||
assert_contains_tlv(tlvs, steering_data, network_data.SteeringData)
|
||||
assert_contains_tlv(tlvs, steering_data, mesh_cop.JoinerUdpPort)
|
||||
|
||||
check_type = CheckType.CONTAIN if response.native_flag else CheckType.OPTIONAL
|
||||
assert_contains_tlv(tlvs, check_type, network_data.CommissionerUdpPort)
|
||||
|
||||
def get_joiner_udp_port_in_discovery_response(command_msg):
|
||||
"""Get the udp port specified in a DISCOVERY RESPONSE message
|
||||
"""
|
||||
tlvs = command_msg.assertMleMessageContainsTlv(mle.ThreadDiscovery).tlvs
|
||||
udp_port_tlv = assert_contains_tlv(tlvs, CheckType.CONTAIN, mesh_cop.JoinerUdpPort)
|
||||
return udp_port_tlv.udp_port
|
||||
|
||||
def check_joiner_commissioning_messages(commissioning_messages):
|
||||
"""Verify COAP messages sent by joiner while commissioning process.
|
||||
"""
|
||||
print(commissioning_messages)
|
||||
assert len(commissioning_messages) >= 2
|
||||
join_fin_req = commissioning_messages[0]
|
||||
assert join_fin_req.type == mesh_cop.MeshCopMessageType.JOIN_FIN_REQ
|
||||
assert_contains_tlv(join_fin_req.tlvs, CheckType.NOT_CONTAIN, mesh_cop.ProvisioningUrl)
|
||||
join_ent_rsp = commissioning_messages[1]
|
||||
assert join_ent_rsp.type == mesh_cop.MeshCopMessageType.JOIN_ENT_RSP
|
||||
|
||||
def check_commissioner_commissioning_messages(commissioning_messages):
|
||||
"""Verify COAP messages sent by commissioner while commissioning process.
|
||||
"""
|
||||
assert any(msg.type == mesh_cop.MeshCopMessageType.JOIN_FIN_RSP for msg in commissioning_messages)
|
||||
|
||||
def check_joiner_router_commissioning_messages(commissioning_messages):
|
||||
"""Verify COAP messages sent by joiner router while commissioning process.
|
||||
"""
|
||||
assert any(msg.type == mesh_cop.MeshCopMessageType.JOIN_ENT_NTF for msg in commissioning_messages)
|
||||
|
||||
@@ -30,8 +30,10 @@ import os
|
||||
from enum import Enum
|
||||
|
||||
import coap
|
||||
import dtls
|
||||
import ipv6
|
||||
import lowpan
|
||||
import mesh_cop
|
||||
import message
|
||||
import mle
|
||||
import net_crypto
|
||||
@@ -75,6 +77,8 @@ VIRTUAL_TIME = int(os.getenv('VIRTUAL_TIME', 0))
|
||||
|
||||
LEADER_NOTIFY_SED_BY_CHILD_UPDATE_REQUEST = True
|
||||
|
||||
PROTOCOL_VERSION = 2
|
||||
|
||||
def create_default_network_data_prefix_sub_tlvs_factories():
|
||||
return {
|
||||
network_data.TlvType.HAS_ROUTE: network_data.HasRouteFactory(
|
||||
@@ -150,6 +154,24 @@ def create_default_mle_tlv_address_registration_factory():
|
||||
addr_compressed_factory=mle.AddressCompressedFactory(),
|
||||
addr_full_factory=mle.AddressFullFactory())
|
||||
|
||||
def create_default_mle_tlv_thread_discovery_factory():
|
||||
return mle.ThreadDiscoveryFactory(
|
||||
thread_discovery_tlvs_factory=create_default_thread_discovery_tlvs_factory())
|
||||
|
||||
def create_default_thread_discovery_tlvs_factory():
|
||||
return mesh_cop.ThreadDiscoveryTlvsFactory(
|
||||
sub_tlvs_factories=create_default_thread_discovery_sub_tlvs_factories())
|
||||
|
||||
def create_default_thread_discovery_sub_tlvs_factories():
|
||||
return {
|
||||
mesh_cop.TlvType.DISCOVERY_REQUEST: mesh_cop.DiscoveryRequestFactory(),
|
||||
mesh_cop.TlvType.DISCOVERY_RESPONSE: mesh_cop.DiscoveryResponseFactory(),
|
||||
mesh_cop.TlvType.EXTENDED_PANID: mesh_cop.ExtendedPanidFactory(),
|
||||
mesh_cop.TlvType.NETWORK_NAME: mesh_cop.NetworkNameFactory(),
|
||||
mesh_cop.TlvType.STEERING_DATA: network_data.SteeringDataFactory(),
|
||||
mesh_cop.TlvType.JOINER_UDP_PORT: mesh_cop.JoinerUdpPortFactory(),
|
||||
mesh_cop.TlvType.COMMISSIONER_UDP_PORT: network_data.CommissionerUdpPortFactory()
|
||||
}
|
||||
|
||||
def create_default_mle_tlvs_factories():
|
||||
return {
|
||||
@@ -177,9 +199,9 @@ def create_default_mle_tlvs_factories():
|
||||
mle.TlvType.PENDING_TIMESTAMP: mle.PendingTimestampFactory(),
|
||||
mle.TlvType.ACTIVE_OPERATIONAL_DATASET: mle.ActiveOperationalDatasetFactory(),
|
||||
mle.TlvType.PENDING_OPERATIONAL_DATASET: mle.PendingOperationalDatasetFactory(),
|
||||
mle.TlvType.THREAD_DISCOVERY: mle.ThreadDiscoveryFactory(),
|
||||
mle.TlvType.TIME_REQUEST: mle.TimeRequestFactory(),
|
||||
mle.TlvType.TIME_PARAMETER: mle.TimeParameterFactory(),
|
||||
mle.TlvType.THREAD_DISCOVERY: create_default_mle_tlv_thread_discovery_factory()
|
||||
}
|
||||
|
||||
|
||||
@@ -253,15 +275,21 @@ def create_default_ipv6_hop_by_hop_options_factory():
|
||||
def create_default_based_on_src_dst_ports_udp_payload_factory(master_key):
|
||||
mle_message_factory = create_default_mle_message_factory(master_key)
|
||||
coap_message_factory = create_default_coap_message_factory()
|
||||
dtls_message_factory = create_default_dtls_message_factory()
|
||||
|
||||
return ipv6.UdpBasedOnSrcDstPortsPayloadFactory(
|
||||
src_dst_port_based_payload_factories={
|
||||
19788: mle_message_factory,
|
||||
61631: coap_message_factory
|
||||
61631: coap_message_factory,
|
||||
1000: dtls_message_factory
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def create_default_dtls_message_factory():
|
||||
return dtls.MessageFactory()
|
||||
|
||||
|
||||
def create_default_ipv6_icmp_body_factories():
|
||||
return {
|
||||
ipv6.ICMP_DESTINATION_UNREACHABLE: ipv6.ICMPv6DestinationUnreachableFactory(),
|
||||
|
||||
@@ -0,0 +1,622 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) 2019, The OpenThread Authors.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
# 1. Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# 2. Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# 3. Neither the name of the copyright holder nor the
|
||||
# names of its contributors may be used to endorse or promote products
|
||||
# derived from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
# POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
|
||||
from binascii import hexlify
|
||||
from enum import IntEnum
|
||||
from functools import reduce
|
||||
import io
|
||||
import math
|
||||
import struct
|
||||
|
||||
from ipv6 import BuildableFromBytes
|
||||
from ipv6 import ConvertibleToBytes
|
||||
|
||||
|
||||
class HandshakeType(IntEnum):
|
||||
HELLO_REQUEST = 0
|
||||
CLIENT_HELLO = 1
|
||||
SERVER_HELLO = 2
|
||||
HELLO_VERIFY_REQUEST = 3
|
||||
CERTIFICATE = 11
|
||||
SERVER_KEY_EXCHANGE = 12
|
||||
CERTIFICATE_REQUEST = 13
|
||||
SERVER_HELLO_DONE = 14
|
||||
CERTIFICATE_VERIFY = 15
|
||||
CLIENT_KEY_EXCHANGE = 16
|
||||
FINISHED = 20
|
||||
|
||||
|
||||
class ContentType(IntEnum):
|
||||
CHANGE_CIPHER_SPEC = 20
|
||||
ALERT = 21
|
||||
HANDSHAKE = 22
|
||||
APPLICATION_DATA = 23
|
||||
|
||||
|
||||
class AlertLevel(IntEnum):
|
||||
WARNING = 1
|
||||
FATAL = 2
|
||||
|
||||
|
||||
class AlertDescription(IntEnum):
|
||||
CLOSE_NOTIFY = 0
|
||||
UNEXPECTED_MESSAGE = 10
|
||||
BAD_RECORD_MAC = 20
|
||||
DECRYPTION_FAILED_RESERVED = 21
|
||||
RECORD_OVERFLOW = 22
|
||||
DECOMPRESSION_FAILURE = 30
|
||||
HANDSHAKE_FAILURE = 40
|
||||
NO_CERTIFICATE_RESERVED = 41
|
||||
BAD_CERTIFICATE = 42
|
||||
UNSUPPORTED_CERTIFICATE = 43
|
||||
CERTIFICATE_REVOKED = 44
|
||||
CERTIFICATE_EXPIRED = 45
|
||||
CERTIFICATE_UNKNOWN = 46
|
||||
ILLEGAL_PARAMETER = 47
|
||||
UNKNOWN_CA = 48
|
||||
ACCESS_DENIED = 49
|
||||
DECODE_ERROR = 50
|
||||
DECRYPT_ERROR = 51
|
||||
EXPORT_RESTRICTION_RESERVED = 60
|
||||
PROTOCOL_VERSION = 70
|
||||
INSUFFICIENT_SECURITY = 71
|
||||
INTERNAL_ERROR = 80
|
||||
USER_CANCELED = 90
|
||||
NO_RENEGOTIATION = 100
|
||||
UNSUPPORTED_EXTENSION = 110
|
||||
|
||||
|
||||
class Record(ConvertibleToBytes, BuildableFromBytes):
|
||||
|
||||
def __init__(self, content_type, version, epoch,
|
||||
sequence_number, length, fragment):
|
||||
self.content_type = content_type
|
||||
self.version = version
|
||||
self.epoch = epoch
|
||||
self.sequence_number = sequence_number
|
||||
self.length = length
|
||||
self.fragment = fragment
|
||||
|
||||
def to_bytes(self):
|
||||
return (struct.pack(">B", self.content_type) +
|
||||
self.version.to_bytes() +
|
||||
struct.pack(">H", self.epoch) +
|
||||
self.sequence_number.to_bytes(6, byteorder='big') +
|
||||
struct.pack(">H", self.length) +
|
||||
self.fragment)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
content_type = ContentType(struct.unpack(">B", data.read(1))[0])
|
||||
version = ProtocolVersion.from_bytes(data)
|
||||
epoch = struct.unpack(">H", data.read(2))[0]
|
||||
sequence_number = struct.unpack(">Q", b'\x00\x00' + data.read(6))[0]
|
||||
length = struct.unpack(">H", data.read(2))[0]
|
||||
fragment = bytes(data.read(length))
|
||||
return cls(content_type, version, epoch, sequence_number, length, fragment)
|
||||
|
||||
def __repr__(self):
|
||||
return "Record(content_type={}, version={}, epoch={}, sequence_number={}, length={})".format(
|
||||
str(self.content_type), self.version, self.epoch, self.sequence_number, self.length)
|
||||
|
||||
|
||||
class Message(ConvertibleToBytes, BuildableFromBytes):
|
||||
|
||||
def __init__(self, content_type):
|
||||
self.content_type = content_type
|
||||
|
||||
def to_bytes(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class HandshakeMessage(Message):
|
||||
|
||||
def __init__(self, handshake_type, length, message_seq,
|
||||
fragment_offset, fragment_length, body):
|
||||
super(HandshakeMessage, self).__init__(ContentType.HANDSHAKE)
|
||||
self.handshake_type = handshake_type
|
||||
self.length = length
|
||||
self.message_seq = message_seq
|
||||
self.fragment_offset = fragment_offset
|
||||
self.fragment_length = fragment_length
|
||||
self.body = body
|
||||
|
||||
def to_bytes(self):
|
||||
return (struct.pack(">B", self.handshake_type) +
|
||||
struct.pack(">I", self.length)[1:] +
|
||||
struct.pack(">H", self.message_seq) +
|
||||
struct.pack(">I", self.fragment_offset)[1:] +
|
||||
struct.pack(">I", self.fragment_length)[1:] +
|
||||
self.body.to_bytes())
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
handshake_type = HandshakeType(struct.unpack(">B", data.read(1))[0])
|
||||
length = struct.unpack(">I", b'\x00' + data.read(3))[0]
|
||||
message_seq = struct.unpack(">H", data.read(2))[0]
|
||||
fragment_offset = struct.unpack(">I", b'\x00' + bytes(data.read(3)))[0]
|
||||
fragment_length = struct.unpack(">I", b'\x00' + bytes(data.read(3)))[0]
|
||||
end_position = data.tell() + fragment_length
|
||||
# TODO(wgtdkp): handle fragmentation
|
||||
|
||||
message_class, body = handshake_map[handshake_type], None
|
||||
if message_class:
|
||||
body = message_class.from_bytes(data)
|
||||
else:
|
||||
print("{} messages are not handled".format(str(handshake_type)))
|
||||
body = bytes(data.read(fragment_length))
|
||||
assert data.tell() == end_position
|
||||
|
||||
return cls(handshake_type, length, message_seq,
|
||||
fragment_offset, fragment_length, body)
|
||||
|
||||
def __repr__(self):
|
||||
return "Handshake(type={}, length={})".format(str(self.handshake_type), self.length)
|
||||
|
||||
|
||||
class ProtocolVersion(ConvertibleToBytes, BuildableFromBytes):
|
||||
|
||||
def __init__(self, major, minor):
|
||||
self.major = major
|
||||
self.minor = minor
|
||||
|
||||
def __eq__(self, other):
|
||||
return (type(self) == type(other) and
|
||||
self.major == other.major and
|
||||
self.minor == other.minor)
|
||||
|
||||
def to_bytes(self):
|
||||
return struct.pack(">BB", self.major, self.minor)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
major, minor = struct.unpack(">BB", data.read(2))
|
||||
return cls(major, minor)
|
||||
|
||||
def __repr__(self):
|
||||
return "ProtocolVersion(major={}, minor={})".format(self.major, self.minor)
|
||||
|
||||
|
||||
class Random(ConvertibleToBytes, BuildableFromBytes):
|
||||
|
||||
random_bytes_length = 28
|
||||
|
||||
def __init__(self, gmt_unix_time, random_bytes):
|
||||
self.gmt_unix_time = gmt_unix_time
|
||||
self.random_bytes = random_bytes
|
||||
assert len(self.random_bytes) == Random.random_bytes_length
|
||||
|
||||
def __eq__(self, other):
|
||||
return (type(self) == type(other) and
|
||||
self.gmt_unix_time == other.gmt_unix_time and
|
||||
self.random_bytes == other.random_bytes)
|
||||
|
||||
def to_bytes(self):
|
||||
return struct.pack(">I", self.gmt_unix_time) + (self.random_bytes)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
gmt_unix_time = struct.unpack(">I", data.read(4))[0]
|
||||
random_bytes = bytes(data.read(cls.random_bytes_length))
|
||||
return cls(gmt_unix_time, random_bytes)
|
||||
|
||||
|
||||
class VariableVector(ConvertibleToBytes):
|
||||
|
||||
def __init__(self, subrange, ele_cls, elements):
|
||||
self.subrange = subrange
|
||||
self.ele_cls = ele_cls
|
||||
self.elements = elements
|
||||
assert self.subrange[0] <= len(self.elements) <= self.subrange[1]
|
||||
|
||||
def length(self):
|
||||
return len(self.elements)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (type(self) == type(other) and
|
||||
self.subrange == other.subrange and
|
||||
self.ele_cls == other.ele_cls and
|
||||
self.elements == other.elements)
|
||||
|
||||
def to_bytes(self):
|
||||
data = reduce(lambda ele, acc: acc + ele.to_bytes(), self.elements)
|
||||
return VariableVector._encode_length(len(data), self.subrange) + data
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, ele_cls, subrange, data):
|
||||
length = cls._decode_length(subrange, data)
|
||||
end_position = data.tell() + length
|
||||
elements = []
|
||||
while data.tell() < end_position:
|
||||
elements.append(ele_cls.from_bytes(data))
|
||||
return cls(subrange, ele_cls, elements)
|
||||
|
||||
@classmethod
|
||||
def _decode_length(cls, subrange, data):
|
||||
length_in_byte = cls._calc_length_in_byte(subrange[1])
|
||||
return reduce(lambda acc, byte: (acc << 8) | byte, bytearray(data.read(length_in_byte)), 0)
|
||||
|
||||
@classmethod
|
||||
def _encode_length(cls, length, subrange):
|
||||
length_in_byte = cls._calc_length_in_byte(subrange[1])
|
||||
ret = bytearray([])
|
||||
while length_in_byte > 0:
|
||||
ret += bytes(length_in_byte & 0xff)
|
||||
length_in_byte == length_in_byte >> 8
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def _calc_length_in_byte(cls, ceiling):
|
||||
return (ceiling.bit_length() + 7) // 8
|
||||
|
||||
|
||||
class Opaque(ConvertibleToBytes, BuildableFromBytes):
|
||||
|
||||
def __init__(self, byte):
|
||||
self.byte = byte
|
||||
|
||||
def __eq__(self, other):
|
||||
return type(self) == type(other) and self.byte == other.byte
|
||||
|
||||
def to_bytes(self):
|
||||
return struct.pack(">B", self.byte)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
return cls(struct.unpack(">B", data.read(1))[0])
|
||||
|
||||
|
||||
class CipherSuite(ConvertibleToBytes, BuildableFromBytes):
|
||||
|
||||
def __init__(self, cipher):
|
||||
self.cipher = cipher
|
||||
|
||||
def __eq__(self, other):
|
||||
return type(self) == type(other) and self.cipher == other.cipher
|
||||
|
||||
def to_bytes(self):
|
||||
return struct.pack(">BB", self.cipher[0], self.cipher[1])
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
return cls(struct.unpack(">BB", data.read(2)))
|
||||
|
||||
def __repr__(self):
|
||||
return "CipherSuite({}, {})".format(self.cipher[0], self.cipher[1])
|
||||
|
||||
|
||||
class CompressionMethod(ConvertibleToBytes, BuildableFromBytes):
|
||||
|
||||
NULL = 0
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __eq__(self, other):
|
||||
return type(self) == type(other)
|
||||
|
||||
def to_bytes(self):
|
||||
return struct.pack(">B", CompressionMethod.NULL)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
method = struct.unpack(">B", data.read(1))[0]
|
||||
assert method == cls.NULL
|
||||
return cls()
|
||||
|
||||
|
||||
class Extension(ConvertibleToBytes, BuildableFromBytes):
|
||||
|
||||
def __init__(self, extension_type, extension_data):
|
||||
self.extension_type = extension_type
|
||||
self.extension_data = extension_data
|
||||
|
||||
def __eq__(self, other):
|
||||
return (type(self) == type(other) and
|
||||
self.extension_type == other.extension_type and
|
||||
self.extension_data == other.extension_data)
|
||||
|
||||
def to_bytes(self):
|
||||
return (struct.pack(">H", self.extension_type) +
|
||||
self.extension_data.to_bytes())
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
extension_type = struct.unpack(">H", data.read(2))[0]
|
||||
extension_data = VariableVector.from_bytes(Opaque, (0, 2**16 - 1), data)
|
||||
return cls(extension_type, extension_data)
|
||||
|
||||
|
||||
class ClientHello(HandshakeMessage):
|
||||
|
||||
def __init__(self, client_version, random, session_id,
|
||||
cookie, cipher_suites, compression_methods, extensions):
|
||||
self.client_version = client_version
|
||||
self.random = random
|
||||
self.session_id = session_id
|
||||
self.cookie = cookie
|
||||
self.cipher_suites = cipher_suites
|
||||
self.compression_methods = compression_methods
|
||||
self.extensions = extensions
|
||||
|
||||
def to_bytes(self):
|
||||
return (self.client_version.to_bytes() +
|
||||
self.random.to_bytes() +
|
||||
self.session_id.to_bytes() +
|
||||
self.cookie.to_bytes() +
|
||||
self.cipher_suites.to_bytes() +
|
||||
self.compression_methods.to_bytes() +
|
||||
self.extensions.to_bytes())
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
client_version = ProtocolVersion.from_bytes(data)
|
||||
random = Random.from_bytes(data)
|
||||
session_id = VariableVector.from_bytes(Opaque, (0, 32), data)
|
||||
cookie = VariableVector.from_bytes(Opaque, (0, 2**8 - 1), data)
|
||||
cipher_suites = VariableVector.from_bytes(CipherSuite, (2, 2**16 - 1), data)
|
||||
compression_methods = VariableVector.from_bytes(CompressionMethod, (1, 2**8 - 1), data)
|
||||
extensions = None
|
||||
if data.tell() < len(data.getvalue()):
|
||||
extensions = VariableVector.from_bytes(Extension, (0, 2**16 - 1), data)
|
||||
return cls(client_version, random, session_id,
|
||||
cookie, cipher_suites, compression_methods, extensions)
|
||||
|
||||
|
||||
class HelloVerifyRequest(HandshakeMessage):
|
||||
|
||||
def __init__(self, server_version, cookie):
|
||||
self.server_version = server_version
|
||||
self.cookie = cookie
|
||||
|
||||
def to_bytes(self):
|
||||
return self.server_version.to_bytes() + self.cookie.to_bytes()
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
server_version = ProtocolVersion.from_bytes(data)
|
||||
cookie = VariableVector.from_bytes(Opaque, (0, 2**8 - 1), data)
|
||||
return cls(server_version, cookie)
|
||||
|
||||
|
||||
class ServerHello(HandshakeMessage):
|
||||
|
||||
def __init__(self, server_version, random, session_id,
|
||||
cipher_suite, compression_method, extensions):
|
||||
self.server_version = server_version
|
||||
self.random = random
|
||||
self.session_id = session_id
|
||||
self.cipher_suite = cipher_suite
|
||||
self.compression_method = compression_method
|
||||
self.extensions = extensions
|
||||
|
||||
def to_bytes(self):
|
||||
return (self.server_version.to_bytes() +
|
||||
self.random.to_bytes() +
|
||||
self.session_id.to_bytes() +
|
||||
self.cipher_suite.to_bytes() +
|
||||
self.compression_method.to_bytes() +
|
||||
self.extensions.to_bytes())
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
server_version = ProtocolVersion.from_bytes(data)
|
||||
random = Random.from_bytes(data)
|
||||
session_id = VariableVector.from_bytes(Opaque, (0, 32), data)
|
||||
cipher_suite = CipherSuite.from_bytes(data)
|
||||
compression_method = CompressionMethod.from_bytes(data)
|
||||
extensions = None
|
||||
if data.tell() < len(data.getvalue()):
|
||||
extensions = VariableVector.from_bytes(Extension, (0, 2**16 - 1), data)
|
||||
return cls(server_version, random, session_id,
|
||||
cipher_suite, compression_method, extensions)
|
||||
|
||||
|
||||
class ServerHelloDone(HandshakeMessage):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def to_bytes(self):
|
||||
return bytearray([])
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
return cls()
|
||||
|
||||
|
||||
class HelloRequest(HandshakeMessage):
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Certificate(HandshakeMessage):
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ServerKeyExchange(HandshakeMessage):
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CertificateRequest(HandshakeMessage):
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CertificateVerify(HandshakeMessage):
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ClientKeyExchange(HandshakeMessage):
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Finished(HandshakeMessage):
|
||||
|
||||
def __init__(self, verify_data):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AlertMessage(Message):
|
||||
|
||||
def __init__(self, level, description):
|
||||
super(AlertMessage, self).__init__(ContentType.ALERT)
|
||||
self.level = level
|
||||
self.description = description
|
||||
|
||||
def to_bytes(self):
|
||||
struct.pack(">BB", self.level, self.description)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
level, description = struct.unpack(">BB", data.read(2))
|
||||
try:
|
||||
return cls(AlertLevel(level), AlertDescription(description))
|
||||
except:
|
||||
data.read()
|
||||
# An AlertMessage could be encrypted and we can't parsing it.
|
||||
return cls(None, None)
|
||||
|
||||
def __repr__(self):
|
||||
return "Alert(level={}, description={})".format(
|
||||
str(self.level), str(self.description))
|
||||
|
||||
|
||||
class ChangeCipherSpecMessage(Message):
|
||||
|
||||
def __init__(self):
|
||||
super(ChangeCipherSpecMessage, self).__init__(ContentType.CHANGE_CIPHER_SPEC)
|
||||
|
||||
def to_bytes(self):
|
||||
return struct.pack(">B", 1)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
assert struct.unpack(">B", data.read(1))[0] == 1
|
||||
return cls()
|
||||
|
||||
def __repr__(self):
|
||||
return "ChangeCipherSpec(value=1)"
|
||||
|
||||
|
||||
class ApplicationDataMessage(Message):
|
||||
|
||||
def __init__(self, raw):
|
||||
super(ApplicationDataMessage, self).__init__(ContentType.APPLICATION_DATA)
|
||||
self.raw = raw
|
||||
self.body = None
|
||||
|
||||
def to_bytes(self):
|
||||
return self.raw
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data):
|
||||
# It is safe to read until the end of this byte stream, because
|
||||
# there is single application data message in a record.
|
||||
length = len(data.getvalue()) - data.tell()
|
||||
return cls(bytes(data.read(length)))
|
||||
|
||||
def __repr__(self):
|
||||
if self.body:
|
||||
return "ApplicationData(body={})".format(self.body)
|
||||
else:
|
||||
return "ApplicationData(raw_length={})".format(len(self.raw))
|
||||
|
||||
|
||||
handshake_map = {
|
||||
HandshakeType.HELLO_REQUEST: None, # HelloRequest
|
||||
HandshakeType.CLIENT_HELLO: ClientHello,
|
||||
HandshakeType.SERVER_HELLO: ServerHello,
|
||||
HandshakeType.HELLO_VERIFY_REQUEST: HelloVerifyRequest,
|
||||
HandshakeType.CERTIFICATE: None, # Certificate
|
||||
HandshakeType.SERVER_KEY_EXCHANGE: None, # ServerKeyExchange
|
||||
HandshakeType.CERTIFICATE_REQUEST: None, # CertificateRequest
|
||||
HandshakeType.SERVER_HELLO_DONE: ServerHelloDone,
|
||||
HandshakeType.CERTIFICATE_VERIFY: None, # CertificateVerify
|
||||
HandshakeType.CLIENT_KEY_EXCHANGE: None, # ClientKeyExchange
|
||||
HandshakeType.FINISHED: None, # Finished
|
||||
}
|
||||
|
||||
|
||||
content_map = {
|
||||
ContentType.CHANGE_CIPHER_SPEC: ChangeCipherSpecMessage,
|
||||
ContentType.ALERT: AlertMessage,
|
||||
ContentType.HANDSHAKE: HandshakeMessage,
|
||||
ContentType.APPLICATION_DATA: ApplicationDataMessage
|
||||
}
|
||||
|
||||
|
||||
class MessageFactory(object):
|
||||
|
||||
last_msg_is_change_cipher_spec = False
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def parse(self, data, message_info):
|
||||
messages = []
|
||||
|
||||
# Multiple records could be sent in the same UDP datagram
|
||||
while data.tell() < len(data.getvalue()):
|
||||
record = Record.from_bytes(data)
|
||||
|
||||
if record.version.major != 0xfe or record.version.minor != 0xfd:
|
||||
raise ValueError("DTLS version error, expect DTLSv1.2")
|
||||
|
||||
last_msg_is_change_cipher_spec = type(self).last_msg_is_change_cipher_spec
|
||||
type(self).last_msg_is_change_cipher_spec = (record.content_type == ContentType.CHANGE_CIPHER_SPEC)
|
||||
|
||||
# FINISHED message immediately follows CHANGE_CIPHER_SPEC message
|
||||
# We skip FINISHED message as it is encrypted
|
||||
if last_msg_is_change_cipher_spec:
|
||||
continue
|
||||
|
||||
fragment_data = io.BytesIO(record.fragment)
|
||||
|
||||
# Multiple handshake messages could be sent in the same record
|
||||
while fragment_data.tell() < len(fragment_data.getvalue()):
|
||||
content_class = content_map[record.content_type]
|
||||
assert content_class
|
||||
messages.append(content_class.from_bytes(fragment_data))
|
||||
|
||||
return messages
|
||||
@@ -0,0 +1,468 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) 2019, The OpenThread Authors.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
# 1. Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# 2. Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# 3. Neither the name of the copyright holder nor the
|
||||
# names of its contributors may be used to endorse or promote products
|
||||
# derived from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
# POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
|
||||
from enum import IntEnum
|
||||
import io
|
||||
import struct
|
||||
|
||||
from network_data import SubTlvsFactory
|
||||
|
||||
class TlvType(IntEnum):
|
||||
EXTENDED_PANID = 2
|
||||
NETWORK_NAME = 3
|
||||
STEERING_DATA = 8
|
||||
COMMISSIONER_UDP_PORT = 15
|
||||
STATE = 16
|
||||
JOINER_UDP_PORT = 18
|
||||
PROVISIONING_URL = 32
|
||||
VENDOR_NAME = 33
|
||||
VENDOR_MODEL = 34
|
||||
VENDOR_SW_VERSION = 35
|
||||
VENDOR_DATA = 36
|
||||
VENDOR_STACK_VERSION = 37
|
||||
DISCOVERY_REQUEST = 128
|
||||
DISCOVERY_RESPONSE = 129
|
||||
|
||||
|
||||
class MeshCopMessageType(IntEnum):
|
||||
JOIN_FIN_REQ = 1
|
||||
JOIN_FIN_RSP = 2
|
||||
JOIN_ENT_NTF = 3
|
||||
JOIN_ENT_RSP = 4
|
||||
|
||||
|
||||
def create_mesh_cop_message_type_set():
|
||||
return [ MeshCopMessageType.JOIN_FIN_REQ,
|
||||
MeshCopMessageType.JOIN_FIN_RSP,
|
||||
MeshCopMessageType.JOIN_ENT_NTF,
|
||||
MeshCopMessageType.JOIN_ENT_RSP ]
|
||||
|
||||
|
||||
class State(object):
|
||||
|
||||
def __init__(self, state):
|
||||
self._state = state
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self._state
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.state == other.state
|
||||
|
||||
def __repr__(self):
|
||||
return "State(state={})".format(self.state)
|
||||
|
||||
|
||||
class StateFactory:
|
||||
|
||||
def parse(self, data):
|
||||
state = ord(data.read(1))
|
||||
return State(state)
|
||||
|
||||
|
||||
class VendorName(object):
|
||||
|
||||
def __init__(self, vendor_name):
|
||||
self._vendor_name = vendor_name
|
||||
|
||||
@property
|
||||
def vendor_name(self):
|
||||
return self._vendor_name
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.vendor_name == other.vendor_name
|
||||
|
||||
def __repr__(self):
|
||||
return "VendorName(vendor_name={})".format(self.vendor_name)
|
||||
|
||||
|
||||
class VendorNameFactory:
|
||||
|
||||
def parse(self, data):
|
||||
vendor_name = data.getvalue().decode('utf-8')
|
||||
return VendorName(vendor_name)
|
||||
|
||||
|
||||
class VendorModel(object):
|
||||
|
||||
def __init__(self, vendor_model):
|
||||
self._vendor_model = vendor_model
|
||||
|
||||
@property
|
||||
def vendor_model(self):
|
||||
return self._vendor_model
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.vendor_model == other.vendor_model
|
||||
|
||||
def __repr__(self):
|
||||
return "VendorModel(vendor_model={})".format(self.vendor_model)
|
||||
|
||||
|
||||
class VendorModelFactory:
|
||||
|
||||
def parse(self, data):
|
||||
vendor_model = data.getvalue().decode('utf-8')
|
||||
return VendorModel(vendor_model)
|
||||
|
||||
class VendorSWVersion(object):
|
||||
|
||||
def __init__(self, vendor_sw_version):
|
||||
self._vendor_sw_version = vendor_sw_version
|
||||
|
||||
@property
|
||||
def vendor_sw_version(self):
|
||||
return self._vendor_sw_version
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.vendor_sw_version == other.vendor_sw_version
|
||||
|
||||
def __repr__(self):
|
||||
return "VendorName(vendor_sw_version={})".format(self.vendor_sw_version)
|
||||
|
||||
|
||||
class VendorSWVersionFactory:
|
||||
|
||||
def parse(self, data):
|
||||
vendor_sw_version = data.getvalue()
|
||||
return VendorSWVersion(vendor_sw_version)
|
||||
|
||||
|
||||
# VendorStackVersion TLV (37)
|
||||
class VendorStackVersion(object):
|
||||
|
||||
def __init__(self, stack_vendor_oui, build, rev, minor, major):
|
||||
self._stack_vendor_oui = stack_vendor_oui
|
||||
self._build = build
|
||||
self._rev = rev
|
||||
self._minor = minor
|
||||
self._major = major
|
||||
return
|
||||
|
||||
@property
|
||||
def stack_vendor_oui(self):
|
||||
return self._stack_vendor_oui
|
||||
|
||||
@property
|
||||
def build(self):
|
||||
return self._build
|
||||
|
||||
@property
|
||||
def rev(self):
|
||||
return self._rev
|
||||
|
||||
@property
|
||||
def minor(self):
|
||||
return self._minor
|
||||
|
||||
@property
|
||||
def major(self):
|
||||
return self._major
|
||||
|
||||
def __repr__(self):
|
||||
return "VendorStackVersion(vendor_stack_version={}, build={}, rev={}, minor={}, major={})".format(self.stack_vendor_oui, self.build, self.rev, self.minor, self.major)
|
||||
|
||||
|
||||
class VendorStackVersionFactory:
|
||||
|
||||
def parse(self, data):
|
||||
stack_vendor_oui = struct.unpack(">H", data.read(2))[0]
|
||||
rest = struct.unpack(">BBBB", data.read(4))
|
||||
build = rest[1] << 4 | (0xf0 & rest[2])
|
||||
rev = 0xf & rest[2]
|
||||
minor = rest[3] & 0xf0
|
||||
major = rest[3] & 0xf
|
||||
return VendorStackVersion(stack_vendor_oui, build, rev, minor, major)
|
||||
|
||||
|
||||
class ProvisioningUrl(object):
|
||||
|
||||
def __init__(self, url):
|
||||
self._url = url
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
return self._url
|
||||
|
||||
def __repr__(self):
|
||||
return "ProvisioningUrl(url={})".format(self.url)
|
||||
|
||||
|
||||
class ProvisioningUrlFactory:
|
||||
|
||||
def parse(self, data):
|
||||
url = data.decode('utf-8')
|
||||
return ProvisioningUrl(url)
|
||||
|
||||
|
||||
class VendorData(object):
|
||||
|
||||
def __init__(self, data):
|
||||
self._vendor_data = data
|
||||
|
||||
@property
|
||||
def vendor_data(self):
|
||||
return self._vendor_data
|
||||
|
||||
def __repr__(self):
|
||||
return "Vendor(url={})".format(self.vendor_data)
|
||||
|
||||
|
||||
class VendorDataFactory(object):
|
||||
|
||||
def parse(self, data):
|
||||
return VendorData(data)
|
||||
|
||||
|
||||
class MeshCopCommand(object):
|
||||
|
||||
def __init__(self, _type, tlvs):
|
||||
self._type = _type
|
||||
self._tlvs = tlvs
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def tlvs(self):
|
||||
return self._tlvs
|
||||
|
||||
def __repr__(self):
|
||||
tlvs_str = ", ".join(["{}".format(tlv) for tlv in self.tlvs])
|
||||
return "MeshCopCommand(type={}, tlvs=[{}])".format(self.type, tlvs_str)
|
||||
|
||||
|
||||
def create_deault_mesh_cop_msg_type_map():
|
||||
return {
|
||||
'JOIN_FIN.req': MeshCopMessageType.JOIN_FIN_REQ,
|
||||
'JOIN_FIN.rsp': MeshCopMessageType.JOIN_FIN_RSP,
|
||||
'JOIN_ENT.ntf': MeshCopMessageType.JOIN_ENT_NTF,
|
||||
'JOIN_ENT.rsp': MeshCopMessageType.JOIN_ENT_RSP
|
||||
}
|
||||
|
||||
|
||||
class MeshCopCommandFactory:
|
||||
|
||||
def __init__(self, tlvs_factories):
|
||||
self._tlvs_factories = tlvs_factories
|
||||
self._mesh_cop_msg_type_map = create_deault_mesh_cop_msg_type_map()
|
||||
|
||||
def _get_length(self, data):
|
||||
return ord(data.read(1))
|
||||
|
||||
def _get_tlv_factory(self, _type):
|
||||
try:
|
||||
return self._tlvs_factories[_type]
|
||||
except KeyError:
|
||||
raise KeyError("Could not find TLV factory. Unsupported TLV type: {}".format(_type))
|
||||
|
||||
def _parse_tlv(self, data):
|
||||
_type = TlvType(ord(data.read(1)))
|
||||
length = self._get_length(data)
|
||||
value = data.read(length)
|
||||
factory = self._get_tlv_factory(_type)
|
||||
if factory == None:
|
||||
return None
|
||||
return factory.parse(io.BytesIO(value))
|
||||
|
||||
def _get_mesh_cop_msg_type(self, msg_type_str):
|
||||
tp = self._mesh_cop_msg_type_map[msg_type_str]
|
||||
if tp == None:
|
||||
raise RuntimeError('Mesh cop message type not found: {}'.format(msg_type_str))
|
||||
return tp
|
||||
|
||||
def parse(self, cmd_type_str, data):
|
||||
cmd_type = self._get_mesh_cop_msg_type(cmd_type_str)
|
||||
|
||||
tlvs = []
|
||||
while data.tell() < len(data.getvalue()):
|
||||
tlv = self._parse_tlv(data)
|
||||
tlvs.append(tlv)
|
||||
|
||||
return MeshCopCommand(cmd_type, tlvs)
|
||||
|
||||
|
||||
def create_default_mesh_cop_tlv_factories():
|
||||
return {
|
||||
TlvType.STATE: StateFactory(),
|
||||
TlvType.PROVISIONING_URL: ProvisioningUrlFactory(),
|
||||
TlvType.VENDOR_NAME: VendorNameFactory(),
|
||||
TlvType.VENDOR_MODEL: VendorModelFactory(),
|
||||
TlvType.VENDOR_SW_VERSION: VendorSWVersionFactory(),
|
||||
TlvType.VENDOR_DATA: VendorDataFactory(),
|
||||
TlvType.VENDOR_STACK_VERSION: VendorStackVersionFactory()
|
||||
}
|
||||
|
||||
|
||||
class ThreadDiscoveryTlvsFactory(SubTlvsFactory):
|
||||
|
||||
def __init__(self, sub_tlvs_factories):
|
||||
super(ThreadDiscoveryTlvsFactory, self).__init__(sub_tlvs_factories)
|
||||
|
||||
|
||||
class DiscoveryRequest(object):
|
||||
|
||||
def __init__(self, version, joiner_flag):
|
||||
self._version = version
|
||||
self._joiner_flag = joiner_flag
|
||||
|
||||
@property
|
||||
def version(self):
|
||||
return self._version
|
||||
|
||||
@property
|
||||
def joiner_flag(self):
|
||||
return self._joiner_flag
|
||||
|
||||
def __eq__(self, other):
|
||||
return (type(self) is type(other)
|
||||
and self.version == other.version
|
||||
and self.joiner_flag == other.joiner_flag)
|
||||
|
||||
def __repr__(self):
|
||||
return "DiscoveryRequest(version={}, joiner_flag={})".format(
|
||||
self.version, self.joiner_flag)
|
||||
|
||||
|
||||
class DiscoveryRequestFactory(object):
|
||||
|
||||
def parse(self, data, message_info):
|
||||
data_byte = struct.unpack(">B", data.read(1))[0]
|
||||
version = (data_byte & 0xf0) >> 4
|
||||
joiner_flag = (data_byte & 0x08) >> 3
|
||||
|
||||
return DiscoveryRequest(version, joiner_flag)
|
||||
|
||||
|
||||
class DiscoveryResponse(object):
|
||||
|
||||
def __init__(self, version, native_flag):
|
||||
self._version = version
|
||||
self._native_flag = native_flag
|
||||
|
||||
@property
|
||||
def version(self):
|
||||
return self._version
|
||||
|
||||
@property
|
||||
def native_flag(self):
|
||||
return self._native_flag
|
||||
|
||||
def __eq__(self, other):
|
||||
return (type(self) is type(other)
|
||||
and self.version == other.version
|
||||
and self.native_flag == other.native_flag)
|
||||
|
||||
def __repr__(self):
|
||||
return "DiscoveryResponse(version={}, native_flag={})".format(
|
||||
self.version, self.native_flag)
|
||||
|
||||
|
||||
class DiscoveryResponseFactory(object):
|
||||
|
||||
def parse(self, data, message_info):
|
||||
data_byte = struct.unpack(">B", data.read(1))[0]
|
||||
version = (data_byte & 0xf0) >> 4
|
||||
native_flag = (data_byte & 0x08) >> 3
|
||||
|
||||
return DiscoveryResponse(version, native_flag)
|
||||
|
||||
|
||||
class ExtendedPanid(object):
|
||||
|
||||
def __init__(self, extended_panid):
|
||||
self._extended_panid = extended_panid
|
||||
|
||||
@property
|
||||
def extended_panid(self):
|
||||
return self._extended_panid
|
||||
|
||||
def __eq__(self, other):
|
||||
return (type(self) is type(other)
|
||||
and self.extended_panid == other.extended_panid)
|
||||
|
||||
def __repr__(self):
|
||||
return "ExtendedPanid(extended_panid={})".format(self.extended_panid)
|
||||
|
||||
|
||||
class ExtendedPanidFactory(object):
|
||||
|
||||
def parse(self, data, message_info):
|
||||
extended_panid = struct.unpack(">Q", data.read(8))[0]
|
||||
return ExtendedPanid(extended_panid)
|
||||
|
||||
|
||||
class NetworkName(object):
|
||||
|
||||
def __init__(self, network_name):
|
||||
self._network_name = network_name
|
||||
|
||||
@property
|
||||
def network_name(self):
|
||||
return self._network_name
|
||||
|
||||
def __eq__(self, other):
|
||||
return (type(self) is type(other)
|
||||
and self.network_name == other.network_name)
|
||||
|
||||
def __repr__(self):
|
||||
return "NetworkName(network_name={})".format(self.network_name)
|
||||
|
||||
|
||||
class NetworkNameFactory(object):
|
||||
|
||||
def parse(self, data, message_info):
|
||||
len = message_info.length
|
||||
network_name = struct.unpack("{}s".format(10), data.read(len))[0]
|
||||
return NetworkName(network_name)
|
||||
|
||||
|
||||
class JoinerUdpPort(object):
|
||||
|
||||
def __init__(self, udp_port):
|
||||
self._udp_port = udp_port
|
||||
|
||||
@property
|
||||
def udp_port(self):
|
||||
return self._udp_port
|
||||
|
||||
def __eq__(self, other):
|
||||
return type(self) is type(other) and self.udp_port == other.udp_port
|
||||
|
||||
def __repr__(self):
|
||||
return "JoinerUdpPort(udp_port={})".format(self.udp_port)
|
||||
|
||||
|
||||
class JoinerUdpPortFactory(object):
|
||||
|
||||
def parse(self, data, message_info):
|
||||
udp_port = struct.unpack(">H", data.read(2))[0]
|
||||
return JoinerUdpPort(udp_port)
|
||||
@@ -34,6 +34,7 @@ import sys
|
||||
|
||||
import coap
|
||||
import common
|
||||
import dtls
|
||||
import ipv6
|
||||
import lowpan
|
||||
import mac802154
|
||||
@@ -49,6 +50,7 @@ class MessageType(IntEnum):
|
||||
BEACON = 4
|
||||
DATA = 5
|
||||
COMMAND = 6
|
||||
DTLS = 7
|
||||
|
||||
|
||||
class Message(object):
|
||||
@@ -61,6 +63,7 @@ class Message(object):
|
||||
self._coap = None
|
||||
self._mle = None
|
||||
self._icmp = None
|
||||
self._dtls = None
|
||||
|
||||
def _extract_udp_datagram(self, udp_datagram):
|
||||
if isinstance(udp_datagram.payload, mle.MleMessage):
|
||||
@@ -71,6 +74,11 @@ class Message(object):
|
||||
self._type = MessageType.COAP
|
||||
self._coap = udp_datagram.payload
|
||||
|
||||
# DTLS message factory returns a list of messages
|
||||
elif isinstance(udp_datagram.payload, list):
|
||||
self._type = MessageType.DTLS
|
||||
self._dtls = udp_datagram.payload
|
||||
|
||||
def _extract_upper_layer_protocol(self, upper_layer_protocol):
|
||||
if isinstance(upper_layer_protocol, ipv6.ICMPv6):
|
||||
self._type = MessageType.ICMP
|
||||
@@ -79,6 +87,32 @@ class Message(object):
|
||||
elif isinstance(upper_layer_protocol, ipv6.UDPDatagram):
|
||||
self._extract_udp_datagram(upper_layer_protocol)
|
||||
|
||||
def try_extract_dtls_messages(self):
|
||||
"""Extract multiple dtls messages that are sent in a single UDP datagram
|
||||
"""
|
||||
if self.type != MessageType.DTLS:
|
||||
return [self.clone()]
|
||||
|
||||
assert isinstance(self.dtls, list)
|
||||
ret = []
|
||||
for dtls in self.dtls:
|
||||
msg = self.clone()
|
||||
msg._dtls = dtls
|
||||
ret.append(msg)
|
||||
return ret
|
||||
|
||||
def clone(self):
|
||||
msg = Message()
|
||||
msg._type = self.type
|
||||
msg._channel = self.channel
|
||||
msg._mac_header = self.mac_header
|
||||
msg._ipv6_packet = self.ipv6_packet
|
||||
msg._coap = self.coap
|
||||
msg._mle = self.mle
|
||||
msg._icmp = self.icmp
|
||||
msg._dtls = self.dtls
|
||||
return msg
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return self._type
|
||||
@@ -141,6 +175,10 @@ class Message(object):
|
||||
def icmp(self, value):
|
||||
self._icmp = value
|
||||
|
||||
@property
|
||||
def dtls(self):
|
||||
return self._dtls
|
||||
|
||||
def get_mle_message_tlv(self, tlv_class_type):
|
||||
if self.type != MessageType.MLE:
|
||||
raise ValueError("Invalid message type. Expected MLE message.")
|
||||
@@ -306,19 +344,30 @@ class Message(object):
|
||||
def isMacAddressTypeLong(self):
|
||||
return self.mac_header.dest_address.type == common.MacAddressType.LONG
|
||||
|
||||
def get_dst_udp_port(self):
|
||||
assert isinstance(self.ipv6_packet.upper_layer_protocol, ipv6.UDPDatagram)
|
||||
return self.ipv6_packet.upper_layer_protocol.header.dst_port
|
||||
|
||||
def __repr__(self):
|
||||
if self.type == MessageType.DTLS and self.dtls.content_type == dtls.ContentType.HANDSHAKE:
|
||||
return "Message(type={})".format(str(self.dtls.handshake_type))
|
||||
return "Message(type={})".format(MessageType(self.type).name)
|
||||
|
||||
|
||||
class MessagesSet(object):
|
||||
|
||||
def __init__(self, messages):
|
||||
def __init__(self, messages, commissioning_messages=[]):
|
||||
self._messages = messages
|
||||
self._commissioning_messages = commissioning_messages
|
||||
|
||||
@property
|
||||
def messages(self):
|
||||
return self._messages
|
||||
|
||||
@property
|
||||
def commissioning_messages(self):
|
||||
return self._commissioning_messages
|
||||
|
||||
def next_coap_message(self, code, uri_path=None, assert_enabled=True):
|
||||
message = None
|
||||
|
||||
@@ -433,6 +482,21 @@ class MessagesSet(object):
|
||||
def next_command_message(self):
|
||||
return self.next_message_of(MessageType.COMMAND)
|
||||
|
||||
def next_dtls_message(self, content_type, handshake_type=None):
|
||||
while self.messages:
|
||||
msg = self.messages.pop(0)
|
||||
if msg.type != MessageType.DTLS:
|
||||
continue
|
||||
if msg.dtls.content_type != content_type:
|
||||
continue
|
||||
if (content_type == dtls.ContentType.HANDSHAKE and
|
||||
msg.dtls.handshake_type != handshake_type):
|
||||
continue
|
||||
return msg
|
||||
|
||||
t = handshake_type if content_type == dtls.ContentType.HANDSHAKE else content_type
|
||||
raise ValueError("Could not find DTLS message of type: {}".format(str(t)))
|
||||
|
||||
def contains_icmp_message(self):
|
||||
for m in self.messages:
|
||||
if m.type == MessageType.ICMP:
|
||||
@@ -472,7 +536,10 @@ class MessagesSet(object):
|
||||
def clone(self):
|
||||
"""Make a copy of current MessageSet.
|
||||
"""
|
||||
return MessagesSet(self.messages[:])
|
||||
return MessagesSet(self.messages[:], self.commissioning_messages[:])
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.messages)
|
||||
|
||||
|
||||
class MessageFactory:
|
||||
@@ -506,7 +573,7 @@ class MessageFactory:
|
||||
message.mac_header = mac_frame.header
|
||||
|
||||
if message.mac_header.frame_type != mac802154.MacHeader.FrameType.DATA:
|
||||
return message
|
||||
return [message]
|
||||
|
||||
message_info = common.MessageInfo()
|
||||
message_info.source_mac_address = message.mac_header.src_address
|
||||
@@ -517,11 +584,11 @@ class MessageFactory:
|
||||
|
||||
ipv6_packet = self._lowpan_parser.parse(lowpan_payload, message_info)
|
||||
if ipv6_packet is None:
|
||||
return message
|
||||
return [message]
|
||||
|
||||
message.ipv6_packet = ipv6_packet
|
||||
|
||||
if message.type == MessageType.MLE:
|
||||
self._add_device_descriptors(message)
|
||||
|
||||
return message
|
||||
return message.try_extract_dtls_messages()
|
||||
|
||||
@@ -1009,17 +1009,30 @@ class PendingOperationalDatasetFactory:
|
||||
return PendingOperationalDataset()
|
||||
|
||||
|
||||
class ThreadDiscovery:
|
||||
# TODO: Not implemented yet
|
||||
class ThreadDiscovery(object):
|
||||
|
||||
def __init__(self):
|
||||
print("ThreadDiscovery is not implemented yet.")
|
||||
def __init__(self, tlvs):
|
||||
self._tlvs = tlvs
|
||||
|
||||
@property
|
||||
def tlvs(self):
|
||||
return self._tlvs
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.tlvs == other.tlvs
|
||||
|
||||
def __repr__(self):
|
||||
return "ThreadDiscovery(tlvs={})".format(self.tlvs)
|
||||
|
||||
class ThreadDiscoveryFactory:
|
||||
|
||||
def __init__(self, thread_discovery_tlvs_factory):
|
||||
self._tlvs_factory = thread_discovery_tlvs_factory
|
||||
|
||||
def parse(self, data, message_info):
|
||||
return ThreadDiscovery()
|
||||
tlvs = self._tlvs_factory.parse(data, message_info)
|
||||
return ThreadDiscovery(tlvs)
|
||||
|
||||
|
||||
class TimeRequest:
|
||||
# TODO: Not implemented yet
|
||||
@@ -1044,6 +1057,7 @@ class TimeParameterFactory:
|
||||
def parse(self, data, message_info):
|
||||
return TimeParameter()
|
||||
|
||||
|
||||
class MleCommand(object):
|
||||
|
||||
def __init__(self, _type, tlvs):
|
||||
|
||||
@@ -45,6 +45,8 @@ class otCli:
|
||||
self.verbose = int(float(os.getenv('VERBOSE', 0)))
|
||||
self.node_type = os.getenv('NODE_TYPE', 'sim')
|
||||
self.simulator = simulator
|
||||
if self.simulator:
|
||||
self.simulator.add_node(self)
|
||||
|
||||
mode = os.environ.get('USE_MTD') is '1' and is_mtd and 'mtd' or 'ftd'
|
||||
|
||||
@@ -152,6 +154,56 @@ class otCli:
|
||||
self.pexpect.expect(pexpect.EOF)
|
||||
self.pexpect = None
|
||||
|
||||
def read_cert_messages_in_commissioning_log(self, timeout=-1):
|
||||
"""Get the log of the traffic after DTLS handshake.
|
||||
"""
|
||||
format_str = br"=+?\[\[THCI\].*?type=%s.*?\].*?=+?[\s\S]+?-{40,}"
|
||||
join_fin_req = format_str % br"JOIN_FIN\.req"
|
||||
join_fin_rsp = format_str % br"JOIN_FIN\.rsp"
|
||||
dummy_format_str = br"\[THCI\].*?type=%s.*?"
|
||||
join_ent_ntf = dummy_format_str % br"JOIN_ENT\.ntf"
|
||||
join_ent_rsp = dummy_format_str % br"JOIN_ENT\.rsp"
|
||||
pattern = (b"(" + join_fin_req + b")|(" + join_fin_rsp + b")|("+ join_ent_ntf + b")|(" + join_ent_rsp + b")")
|
||||
|
||||
messages = []
|
||||
# There are at most 4 cert messages both for joiner and commissioner
|
||||
for _ in range(0, 4):
|
||||
try:
|
||||
self._expect(pattern, timeout=timeout)
|
||||
log = self.pexpect.match.group(0)
|
||||
messages.append(self._extract_cert_message(log))
|
||||
except:
|
||||
break
|
||||
return messages
|
||||
|
||||
def _extract_cert_message(self, log):
|
||||
res = re.search(br"direction=\w+", log)
|
||||
assert res
|
||||
direction = res.group(0).split(b'=')[1].strip()
|
||||
|
||||
res = re.search(br"type=\S+", log)
|
||||
assert res
|
||||
type = res.group(0).split(b'=')[1].strip()
|
||||
|
||||
payload = bytearray([])
|
||||
payload_len = 0
|
||||
if type in [b"JOIN_FIN.req", b"JOIN_FIN.rsp"]:
|
||||
res = re.search(br"len=\d+", log)
|
||||
assert res
|
||||
payload_len = int(res.group(0).split(b'=')[1].strip())
|
||||
|
||||
hex_pattern = br"\|(\s([0-9a-fA-F]{2}|\.\.))+?\s+?\|"
|
||||
while True:
|
||||
res = re.search(hex_pattern, log)
|
||||
if not res:
|
||||
break
|
||||
data = [int(hex, 16) for hex in res.group(0)[1:-1].split(b' ') if hex and hex != b'..']
|
||||
payload += bytearray(data)
|
||||
log = log[res.end()-1:]
|
||||
|
||||
assert len(payload) == payload_len
|
||||
return (direction, type, payload)
|
||||
|
||||
def send_command(self, cmd, go=True):
|
||||
print("%d: %s" % (self.nodeid, cmd))
|
||||
self.pexpect.send(cmd + '\n')
|
||||
@@ -164,7 +216,7 @@ class otCli:
|
||||
self._expect('Commands:')
|
||||
commands = []
|
||||
while True:
|
||||
i = self._expect(['Done', '(\S+)'])
|
||||
i = self._expect(['Done', r'(\S+)'])
|
||||
if i != 0:
|
||||
commands.append(self.pexpect.match.groups()[0])
|
||||
else:
|
||||
@@ -280,7 +332,7 @@ class otCli:
|
||||
|
||||
def get_channel(self):
|
||||
self.send_command('channel')
|
||||
i = self._expect('(\d+)\r?\n')
|
||||
i = self._expect(r'(\d+)\r?\n')
|
||||
if i == 0:
|
||||
channel = int(self.pexpect.match.groups()[0])
|
||||
self._expect('Done')
|
||||
@@ -306,7 +358,7 @@ class otCli:
|
||||
|
||||
def get_key_sequence_counter(self):
|
||||
self.send_command('keysequence counter')
|
||||
i = self._expect('(\d+)\r?\n')
|
||||
i = self._expect(r'(\d+)\r?\n')
|
||||
if i == 0:
|
||||
key_sequence_counter = int(self.pexpect.match.groups()[0])
|
||||
self._expect('Done')
|
||||
@@ -330,7 +382,7 @@ class otCli:
|
||||
def get_network_name(self):
|
||||
self.send_command('networkname')
|
||||
while True:
|
||||
i = self._expect(['Done', '(\S+)'])
|
||||
i = self._expect(['Done', r'(\S+)'])
|
||||
if i != 0:
|
||||
network_name = self.pexpect.match.groups()[0].decode('utf-8')
|
||||
else:
|
||||
@@ -357,7 +409,7 @@ class otCli:
|
||||
|
||||
def get_partition_id(self):
|
||||
self.send_command('leaderpartitionid')
|
||||
i = self._expect('(\d+)\r?\n')
|
||||
i = self._expect(r'(\d+)\r?\n')
|
||||
if i == 0:
|
||||
weight = self.pexpect.match.groups()[0]
|
||||
self._expect('Done')
|
||||
@@ -397,7 +449,7 @@ class otCli:
|
||||
|
||||
def get_timeout(self):
|
||||
self.send_command('childtimeout')
|
||||
i = self._expect('(\d+)\r?\n')
|
||||
i = self._expect(r'(\d+)\r?\n')
|
||||
if i == 0:
|
||||
timeout = self.pexpect.match.groups()[0]
|
||||
self._expect('Done')
|
||||
@@ -415,7 +467,7 @@ class otCli:
|
||||
|
||||
def get_weight(self):
|
||||
self.send_command('leaderweight')
|
||||
i = self._expect('(\d+)\r?\n')
|
||||
i = self._expect(r'(\d+)\r?\n')
|
||||
if i == 0:
|
||||
weight = self.pexpect.match.groups()[0]
|
||||
self._expect('Done')
|
||||
@@ -436,7 +488,7 @@ class otCli:
|
||||
self.send_command('ipaddr')
|
||||
|
||||
while True:
|
||||
i = self._expect(['(\S+(:\S*)+)\r?\n', 'Done'])
|
||||
i = self._expect([r'(\S+(:\S*)+)\r?\n', 'Done'])
|
||||
if i == 0:
|
||||
addrs.append(self.pexpect.match.groups()[0].decode("utf-8"))
|
||||
elif i == 1:
|
||||
@@ -464,7 +516,7 @@ class otCli:
|
||||
self.send_command('eidcache')
|
||||
|
||||
while True:
|
||||
i = self._expect(['([a-fA-F0-9\:]+) ([a-fA-F0-9]+)\r?\n', 'Done'])
|
||||
i = self._expect([r'([a-fA-F0-9\:]+) ([a-fA-F0-9]+)\r?\n', 'Done'])
|
||||
if i == 0:
|
||||
eid = self.pexpect.match.groups()[0].decode("utf-8")
|
||||
rloc = self.pexpect.match.groups()[1].decode("utf-8")
|
||||
@@ -553,7 +605,7 @@ class otCli:
|
||||
|
||||
def get_context_reuse_delay(self):
|
||||
self.send_command('contextreusedelay')
|
||||
i = self._expect('(\d+)\r?\n')
|
||||
i = self._expect(r'(\d+)\r?\n')
|
||||
if i == 0:
|
||||
timeout = self.pexpect.match.groups()[0]
|
||||
self._expect('Done')
|
||||
@@ -617,7 +669,7 @@ class otCli:
|
||||
|
||||
results = []
|
||||
while True:
|
||||
i = self._expect(['\|\s(\S+)\s+\|\s(\S+)\s+\|\s([0-9a-fA-F]{4})\s\|\s([0-9a-fA-F]{16})\s\|\s(\d+)\r?\n',
|
||||
i = self._expect([r'\|\s(\S+)\s+\|\s(\S+)\s+\|\s([0-9a-fA-F]{4})\s\|\s([0-9a-fA-F]{16})\s\|\s(\d+)\r?\n',
|
||||
'Done'])
|
||||
if i == 0:
|
||||
results.append(self.pexpect.match.groups())
|
||||
@@ -640,7 +692,7 @@ class otCli:
|
||||
try:
|
||||
responders = {}
|
||||
while len(responders) < num_responses:
|
||||
i = self._expect(['from (\S+):'])
|
||||
i = self._expect([r'from (\S+):'])
|
||||
if i == 0:
|
||||
responders[self.pexpect.match.groups()[0]] = 1
|
||||
self._expect('\n')
|
||||
|
||||
@@ -76,6 +76,7 @@ class PcapCodec(object):
|
||||
timestamp = self._get_timestamp()
|
||||
pkt = self.encode_frame(frame, *timestamp)
|
||||
self._pcap_file.write(pkt)
|
||||
self._pcap_file.flush()
|
||||
|
||||
def __del__(self):
|
||||
self._pcap_file.close()
|
||||
|
||||
@@ -39,6 +39,8 @@ import time
|
||||
|
||||
import io
|
||||
import config
|
||||
import dtls
|
||||
import mesh_cop
|
||||
import message
|
||||
import pcap
|
||||
|
||||
@@ -46,9 +48,47 @@ def dbg_print(*args):
|
||||
if False:
|
||||
print(args)
|
||||
|
||||
class RealTime:
|
||||
class BaseSimulator(object):
|
||||
|
||||
def __init__(self):
|
||||
self._nodes = {}
|
||||
self.commissioning_messages = {}
|
||||
self._payload_parse_factory = mesh_cop.MeshCopCommandFactory(mesh_cop.create_default_mesh_cop_tlv_factories())
|
||||
self._mesh_cop_msg_set = mesh_cop.create_mesh_cop_message_type_set()
|
||||
|
||||
def __del__(self):
|
||||
self._nodes = None
|
||||
|
||||
def add_node(self, node):
|
||||
self._nodes[node.nodeid] = node
|
||||
self.commissioning_messages[node.nodeid] = []
|
||||
|
||||
def set_lowpan_context(self, cid, prefix):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_messages_sent_by(self, nodeid):
|
||||
raise NotImplementedError
|
||||
|
||||
def go(self, duration, nodeid=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def stop(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def read_cert_messages_in_commissioning_log(self, nodeids):
|
||||
for nodeid in nodeids:
|
||||
node = self._nodes[nodeid]
|
||||
if not node:
|
||||
continue
|
||||
for direction, type, payload in node.read_cert_messages_in_commissioning_log():
|
||||
if direction == b'send':
|
||||
msg = self._payload_parse_factory.parse(type.decode("utf-8"), io.BytesIO(payload))
|
||||
self.commissioning_messages[nodeid].append(msg)
|
||||
|
||||
class RealTime(BaseSimulator):
|
||||
|
||||
def __init__(self):
|
||||
super(RealTime, self).__init__()
|
||||
self._sniffer = config.create_default_thread_sniffer()
|
||||
self._sniffer.start()
|
||||
|
||||
@@ -56,7 +96,10 @@ class RealTime:
|
||||
self._sniffer.set_lowpan_context(cid, prefix)
|
||||
|
||||
def get_messages_sent_by(self, nodeid):
|
||||
return self._sniffer.get_messages_sent_by(nodeid)
|
||||
messages = self._sniffer.get_messages_sent_by(nodeid).messages
|
||||
ret = message.MessagesSet(messages, self.commissioning_messages[nodeid])
|
||||
self.commissioning_messages[nodeid] = []
|
||||
return ret
|
||||
|
||||
def go(self, duration, nodeid=None):
|
||||
time.sleep(duration)
|
||||
@@ -64,7 +107,7 @@ class RealTime:
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
class VirtualTime:
|
||||
class VirtualTime(BaseSimulator):
|
||||
|
||||
OT_SIM_EVENT_ALARM_FIRED = 0
|
||||
OT_SIM_EVENT_RADIO_RECEIVED = 1
|
||||
@@ -91,6 +134,7 @@ class VirtualTime:
|
||||
NCP_SIM = os.getenv('NODE_TYPE', 'sim') == 'ncp-sim'
|
||||
|
||||
def __init__(self):
|
||||
super(VirtualTime, self).__init__()
|
||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
|
||||
ip = '127.0.0.1'
|
||||
@@ -126,14 +170,13 @@ class VirtualTime:
|
||||
|
||||
# Ignore any exceptions
|
||||
try:
|
||||
msg = self._message_factory.create(io.BytesIO(message))
|
||||
|
||||
if msg is not None:
|
||||
self.devices[addr]['msgs'].append(msg)
|
||||
messages = self._message_factory.create(io.BytesIO(message))
|
||||
self.devices[addr]['msgs'] += messages
|
||||
|
||||
except Exception as e:
|
||||
# Just print the exception to the console
|
||||
print("EXCEPTION: %s" % e)
|
||||
traceback.print_exc()
|
||||
|
||||
def set_lowpan_context(self, cid, prefix):
|
||||
self._message_factory.set_lowpan_context(cid, prefix)
|
||||
@@ -154,7 +197,9 @@ class VirtualTime:
|
||||
messages = self.devices[addr]['msgs']
|
||||
self.devices[addr]['msgs'] = []
|
||||
|
||||
return message.MessagesSet(messages)
|
||||
ret = message.MessagesSet(messages, self.commissioning_messages[nodeid])
|
||||
self.commissioning_messages[nodeid] = []
|
||||
return ret
|
||||
|
||||
def _is_radio(self, addr):
|
||||
return addr[1] < self.BASE_PORT * 2
|
||||
|
||||
@@ -33,6 +33,7 @@ import logging
|
||||
import os
|
||||
import pcap
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
try:
|
||||
import Queue
|
||||
@@ -85,16 +86,15 @@ class Sniffer:
|
||||
|
||||
# Ignore any exceptions
|
||||
try:
|
||||
msg = self._message_factory.create(io.BytesIO(data))
|
||||
|
||||
if msg is not None:
|
||||
self.logger.debug("Received message: {}".format(msg))
|
||||
messages = self._message_factory.create(io.BytesIO(data))
|
||||
self.logger.debug("Received messages: {}".format(messages))
|
||||
for msg in messages:
|
||||
self._buckets[nodeid].put(msg)
|
||||
|
||||
except Exception as e:
|
||||
# Just print the exception to the console
|
||||
print("EXCEPTION: %s" % e)
|
||||
pass
|
||||
traceback.print_exc()
|
||||
|
||||
self.logger.debug("Sniffer stopped.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user