[tests] support testing commissioning (#3441)

Add support for testing commissioning process, including:

1. Add DTLS, Thread Discovery and commissioning messages(JOIN_FIN.req etc) parsing;
2. Support parsing log from UART response, and construct decrypted messages;
3. Common commands for commissioning related verification;
This commit is contained in:
wgtdkp
2019-02-27 07:49:57 +08:00
committed by Jonathan Hui
parent cc37bd1dfb
commit dc947f0152
13 changed files with 1420 additions and 47 deletions
+10 -10
View File
@@ -416,20 +416,20 @@ build_samr21() {
export ASAN_OPTIONS=symbolize=1 || die
export DISTCHECK_CONFIGURE_FLAGS= CPPFLAGS=-DOPENTHREAD_POSIX_VIRTUAL_TIME=1 || die
./bootstrap || die
make -f examples/Makefile-posix distcheck || die
CERT_LOG=1 make -f examples/Makefile-posix distcheck || die
}
[ $BUILD_TARGET != posix-32-bit ] || {
./bootstrap || die
COVERAGE=1 CFLAGS=-m32 CXXFLAGS=-m32 LDFLAGS=-m32 make -f examples/Makefile-posix check || die
CERT_LOG=1 COVERAGE=1 CFLAGS=-m32 CXXFLAGS=-m32 LDFLAGS=-m32 make -f examples/Makefile-posix check || die
}
[ $BUILD_TARGET != posix-app-cli ] || {
./bootstrap || die
# enable code coverage for OpenThread transceiver only
COVERAGE=1 VIRTUAL_TIME_UART=1 make -f examples/Makefile-posix || die
COVERAGE=1 make -f src/posix/Makefile-posix || die
COVERAGE=1 PYTHONUNBUFFERED=1 OT_CLI_PATH="$(pwd)/$(ls output/posix/*/bin/ot-cli)" RADIO_DEVICE="$(pwd)/$(ls output/*/bin/ot-ncp-radio)" make -f src/posix/Makefile-posix check || die
CERT_LOG=1 COVERAGE=1 VIRTUAL_TIME_UART=1 make -f examples/Makefile-posix || die
CERT_LOG=1 COVERAGE=1 make -f src/posix/Makefile-posix || die
CERT_LOG=1 COVERAGE=1 PYTHONUNBUFFERED=1 OT_CLI_PATH="$(pwd)/$(ls output/posix/*/bin/ot-cli)" RADIO_DEVICE="$(pwd)/$(ls output/*/bin/ot-ncp-radio)" make -f src/posix/Makefile-posix check || die
}
[ $BUILD_TARGET != posix-app-pty ] || {
@@ -439,7 +439,7 @@ build_samr21() {
[ $BUILD_TARGET != posix-mtd ] || {
./bootstrap || die
COVERAGE=1 CFLAGS=-m32 CXXFLAGS=-m32 LDFLAGS=-m32 USE_MTD=1 make -f examples/Makefile-posix check || die
CERT_LOG=1 COVERAGE=1 CFLAGS=-m32 CXXFLAGS=-m32 LDFLAGS=-m32 USE_MTD=1 make -f examples/Makefile-posix check || die
}
[ $BUILD_TARGET != posix-ncp-spi ] || {
@@ -449,15 +449,15 @@ build_samr21() {
[ $BUILD_TARGET != posix-app-ncp ] || {
./bootstrap || die
COVERAGE=1 VIRTUAL_TIME_UART=1 make -f examples/Makefile-posix || die
CERT_LOG=1 COVERAGE=1 VIRTUAL_TIME_UART=1 make -f examples/Makefile-posix || die
# enable code coverage for OpenThread posix radio
COVERAGE=1 make -f src/posix/Makefile-posix || die
COVERAGE=1 PYTHONUNBUFFERED=1 OT_NCP_PATH="$(pwd)/$(ls output/posix/*/bin/ot-ncp)" RADIO_DEVICE="$(pwd)/$(ls output/*/bin/ot-ncp-radio)" NODE_TYPE=ncp-sim make -f src/posix/Makefile-posix check || die
CERT_LOG=1 COVERAGE=1 make -f src/posix/Makefile-posix || die
CERT_LOG=1 COVERAGE=1 PYTHONUNBUFFERED=1 OT_NCP_PATH="$(pwd)/$(ls output/posix/*/bin/ot-ncp)" RADIO_DEVICE="$(pwd)/$(ls output/*/bin/ot-ncp-radio)" NODE_TYPE=ncp-sim make -f src/posix/Makefile-posix check || die
}
[ $BUILD_TARGET != posix-ncp ] || {
./bootstrap || die
COVERAGE=1 PYTHONUNBUFFERED=1 NODE_TYPE=ncp-sim make -f examples/Makefile-posix check || die
CERT_LOG=1 COVERAGE=1 PYTHONUNBUFFERED=1 NODE_TYPE=ncp-sim make -f examples/Makefile-posix check || die
}
[ $BUILD_TARGET != toranj-test-framework ] || {
+1
View File
@@ -35,6 +35,7 @@ AM_DISTCHECK_CONFIGURE_FLAGS = \
--enable-application-coap \
--enable-application-coap-secure \
--enable-border-router \
--enable-cert-log \
--enable-cli \
--enable-commissioner \
--enable-dhcp6-client \
+2
View File
@@ -125,9 +125,11 @@ EXTRA_DIST = \
command.py \
common.py \
config.py \
dtls.py \
ipv6.py \
lowpan.py \
mac802154.py \
mesh_cop.py \
message.py \
mle.py \
net_crypto.py \
+73
View File
@@ -33,11 +33,14 @@ import sys
import ipv6
import network_data
import network_layer
import common
import config
import mesh_cop
import mle
from collections import Counter
from enum import IntEnum
from network_data import Prefix, BorderRouter, LowpanId
class CheckType(IntEnum):
CONTAIN = 0
@@ -270,6 +273,7 @@ def check_parent_request(command_msg, is_first_request):
elif not scan_mask.end_device:
raise ValueError("Second parent request without E bit set")
def check_parent_response(command_msg, mle_frame_counter = CheckType.OPTIONAL):
"""Verify a properly formatted Parent Response command message.
"""
@@ -478,3 +482,72 @@ def check_address_registration_tlv(addr_reg_tlv, address_set):
"""Verify all addresses contained in address_set are contained in add_reg_tlv
"""
assert all(addr in addr_reg_tlv.addresses for addr in address_set), 'Some addresses are not included in AddressRegistration TLV'
def assert_contains_tlv(tlvs, check_type, tlv_type):
"""Assert a tlv list contains specific tlv and return the first qualified.
"""
tlvs = [tlv for tlv in tlvs if isinstance(tlv, tlv_type)]
if check_type is CheckType.CONTAIN:
assert tlvs
return tlvs[0]
elif check_type is CheckType.NOT_CONTAIN:
assert not tlvs
return None
elif check_type is CheckType.OPTIONAL:
return None
else:
raise ValueError("Invalid check type: {}".format(check_type))
def check_discovery_request(command_msg):
"""Verify a properly formatted Thread Discovery Request command message.
"""
assert not isinstance(command_msg.mle, mle.MleMessageSecured)
tlvs = command_msg.assertMleMessageContainsTlv(mle.ThreadDiscovery).tlvs
request = assert_contains_tlv(tlvs, CheckType.CONTAIN, mesh_cop.DiscoveryRequest)
assert request.version == config.PROTOCOL_VERSION
def check_discovery_response(command_msg, request_src_addr, steering_data=CheckType.OPTIONAL):
"""Verify a properly formatted Thread Discovery Response command message.
"""
assert not isinstance(command_msg.mle, mle.MleMessageSecured)
assert command_msg.mac_header.src_address.type == common.MacAddressType.LONG
assert command_msg.mac_header.dest_address == request_src_addr
tlvs = command_msg.assertMleMessageContainsTlv(mle.ThreadDiscovery).tlvs
response = assert_contains_tlv(tlvs, CheckType.CONTAIN, mesh_cop.DiscoveryResponse)
assert response.version == config.PROTOCOL_VERSION
assert_contains_tlv(tlvs, CheckType.CONTAIN, mesh_cop.ExtendedPanid)
assert_contains_tlv(tlvs, CheckType.CONTAIN, mesh_cop.NetworkName)
assert_contains_tlv(tlvs, steering_data, network_data.SteeringData)
assert_contains_tlv(tlvs, steering_data, mesh_cop.JoinerUdpPort)
check_type = CheckType.CONTAIN if response.native_flag else CheckType.OPTIONAL
assert_contains_tlv(tlvs, check_type, network_data.CommissionerUdpPort)
def get_joiner_udp_port_in_discovery_response(command_msg):
"""Get the udp port specified in a DISCOVERY RESPONSE message
"""
tlvs = command_msg.assertMleMessageContainsTlv(mle.ThreadDiscovery).tlvs
udp_port_tlv = assert_contains_tlv(tlvs, CheckType.CONTAIN, mesh_cop.JoinerUdpPort)
return udp_port_tlv.udp_port
def check_joiner_commissioning_messages(commissioning_messages):
"""Verify COAP messages sent by joiner while commissioning process.
"""
print(commissioning_messages)
assert len(commissioning_messages) >= 2
join_fin_req = commissioning_messages[0]
assert join_fin_req.type == mesh_cop.MeshCopMessageType.JOIN_FIN_REQ
assert_contains_tlv(join_fin_req.tlvs, CheckType.NOT_CONTAIN, mesh_cop.ProvisioningUrl)
join_ent_rsp = commissioning_messages[1]
assert join_ent_rsp.type == mesh_cop.MeshCopMessageType.JOIN_ENT_RSP
def check_commissioner_commissioning_messages(commissioning_messages):
"""Verify COAP messages sent by commissioner while commissioning process.
"""
assert any(msg.type == mesh_cop.MeshCopMessageType.JOIN_FIN_RSP for msg in commissioning_messages)
def check_joiner_router_commissioning_messages(commissioning_messages):
"""Verify COAP messages sent by joiner router while commissioning process.
"""
assert any(msg.type == mesh_cop.MeshCopMessageType.JOIN_ENT_NTF for msg in commissioning_messages)
+30 -2
View File
@@ -30,8 +30,10 @@ import os
from enum import Enum
import coap
import dtls
import ipv6
import lowpan
import mesh_cop
import message
import mle
import net_crypto
@@ -75,6 +77,8 @@ VIRTUAL_TIME = int(os.getenv('VIRTUAL_TIME', 0))
LEADER_NOTIFY_SED_BY_CHILD_UPDATE_REQUEST = True
PROTOCOL_VERSION = 2
def create_default_network_data_prefix_sub_tlvs_factories():
return {
network_data.TlvType.HAS_ROUTE: network_data.HasRouteFactory(
@@ -150,6 +154,24 @@ def create_default_mle_tlv_address_registration_factory():
addr_compressed_factory=mle.AddressCompressedFactory(),
addr_full_factory=mle.AddressFullFactory())
def create_default_mle_tlv_thread_discovery_factory():
return mle.ThreadDiscoveryFactory(
thread_discovery_tlvs_factory=create_default_thread_discovery_tlvs_factory())
def create_default_thread_discovery_tlvs_factory():
return mesh_cop.ThreadDiscoveryTlvsFactory(
sub_tlvs_factories=create_default_thread_discovery_sub_tlvs_factories())
def create_default_thread_discovery_sub_tlvs_factories():
return {
mesh_cop.TlvType.DISCOVERY_REQUEST: mesh_cop.DiscoveryRequestFactory(),
mesh_cop.TlvType.DISCOVERY_RESPONSE: mesh_cop.DiscoveryResponseFactory(),
mesh_cop.TlvType.EXTENDED_PANID: mesh_cop.ExtendedPanidFactory(),
mesh_cop.TlvType.NETWORK_NAME: mesh_cop.NetworkNameFactory(),
mesh_cop.TlvType.STEERING_DATA: network_data.SteeringDataFactory(),
mesh_cop.TlvType.JOINER_UDP_PORT: mesh_cop.JoinerUdpPortFactory(),
mesh_cop.TlvType.COMMISSIONER_UDP_PORT: network_data.CommissionerUdpPortFactory()
}
def create_default_mle_tlvs_factories():
return {
@@ -177,9 +199,9 @@ def create_default_mle_tlvs_factories():
mle.TlvType.PENDING_TIMESTAMP: mle.PendingTimestampFactory(),
mle.TlvType.ACTIVE_OPERATIONAL_DATASET: mle.ActiveOperationalDatasetFactory(),
mle.TlvType.PENDING_OPERATIONAL_DATASET: mle.PendingOperationalDatasetFactory(),
mle.TlvType.THREAD_DISCOVERY: mle.ThreadDiscoveryFactory(),
mle.TlvType.TIME_REQUEST: mle.TimeRequestFactory(),
mle.TlvType.TIME_PARAMETER: mle.TimeParameterFactory(),
mle.TlvType.THREAD_DISCOVERY: create_default_mle_tlv_thread_discovery_factory()
}
@@ -253,15 +275,21 @@ def create_default_ipv6_hop_by_hop_options_factory():
def create_default_based_on_src_dst_ports_udp_payload_factory(master_key):
mle_message_factory = create_default_mle_message_factory(master_key)
coap_message_factory = create_default_coap_message_factory()
dtls_message_factory = create_default_dtls_message_factory()
return ipv6.UdpBasedOnSrcDstPortsPayloadFactory(
src_dst_port_based_payload_factories={
19788: mle_message_factory,
61631: coap_message_factory
61631: coap_message_factory,
1000: dtls_message_factory
}
)
def create_default_dtls_message_factory():
return dtls.MessageFactory()
def create_default_ipv6_icmp_body_factories():
return {
ipv6.ICMP_DESTINATION_UNREACHABLE: ipv6.ICMPv6DestinationUnreachableFactory(),
+622
View File
@@ -0,0 +1,622 @@
#!/usr/bin/env python
#
# Copyright (c) 2019, The OpenThread Authors.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
from binascii import hexlify
from enum import IntEnum
from functools import reduce
import io
import math
import struct
from ipv6 import BuildableFromBytes
from ipv6 import ConvertibleToBytes
class HandshakeType(IntEnum):
HELLO_REQUEST = 0
CLIENT_HELLO = 1
SERVER_HELLO = 2
HELLO_VERIFY_REQUEST = 3
CERTIFICATE = 11
SERVER_KEY_EXCHANGE = 12
CERTIFICATE_REQUEST = 13
SERVER_HELLO_DONE = 14
CERTIFICATE_VERIFY = 15
CLIENT_KEY_EXCHANGE = 16
FINISHED = 20
class ContentType(IntEnum):
CHANGE_CIPHER_SPEC = 20
ALERT = 21
HANDSHAKE = 22
APPLICATION_DATA = 23
class AlertLevel(IntEnum):
WARNING = 1
FATAL = 2
class AlertDescription(IntEnum):
CLOSE_NOTIFY = 0
UNEXPECTED_MESSAGE = 10
BAD_RECORD_MAC = 20
DECRYPTION_FAILED_RESERVED = 21
RECORD_OVERFLOW = 22
DECOMPRESSION_FAILURE = 30
HANDSHAKE_FAILURE = 40
NO_CERTIFICATE_RESERVED = 41
BAD_CERTIFICATE = 42
UNSUPPORTED_CERTIFICATE = 43
CERTIFICATE_REVOKED = 44
CERTIFICATE_EXPIRED = 45
CERTIFICATE_UNKNOWN = 46
ILLEGAL_PARAMETER = 47
UNKNOWN_CA = 48
ACCESS_DENIED = 49
DECODE_ERROR = 50
DECRYPT_ERROR = 51
EXPORT_RESTRICTION_RESERVED = 60
PROTOCOL_VERSION = 70
INSUFFICIENT_SECURITY = 71
INTERNAL_ERROR = 80
USER_CANCELED = 90
NO_RENEGOTIATION = 100
UNSUPPORTED_EXTENSION = 110
class Record(ConvertibleToBytes, BuildableFromBytes):
def __init__(self, content_type, version, epoch,
sequence_number, length, fragment):
self.content_type = content_type
self.version = version
self.epoch = epoch
self.sequence_number = sequence_number
self.length = length
self.fragment = fragment
def to_bytes(self):
return (struct.pack(">B", self.content_type) +
self.version.to_bytes() +
struct.pack(">H", self.epoch) +
self.sequence_number.to_bytes(6, byteorder='big') +
struct.pack(">H", self.length) +
self.fragment)
@classmethod
def from_bytes(cls, data):
content_type = ContentType(struct.unpack(">B", data.read(1))[0])
version = ProtocolVersion.from_bytes(data)
epoch = struct.unpack(">H", data.read(2))[0]
sequence_number = struct.unpack(">Q", b'\x00\x00' + data.read(6))[0]
length = struct.unpack(">H", data.read(2))[0]
fragment = bytes(data.read(length))
return cls(content_type, version, epoch, sequence_number, length, fragment)
def __repr__(self):
return "Record(content_type={}, version={}, epoch={}, sequence_number={}, length={})".format(
str(self.content_type), self.version, self.epoch, self.sequence_number, self.length)
class Message(ConvertibleToBytes, BuildableFromBytes):
def __init__(self, content_type):
self.content_type = content_type
def to_bytes(self):
raise NotImplementedError
@classmethod
def from_bytes(cls, data):
raise NotImplementedError
class HandshakeMessage(Message):
def __init__(self, handshake_type, length, message_seq,
fragment_offset, fragment_length, body):
super(HandshakeMessage, self).__init__(ContentType.HANDSHAKE)
self.handshake_type = handshake_type
self.length = length
self.message_seq = message_seq
self.fragment_offset = fragment_offset
self.fragment_length = fragment_length
self.body = body
def to_bytes(self):
return (struct.pack(">B", self.handshake_type) +
struct.pack(">I", self.length)[1:] +
struct.pack(">H", self.message_seq) +
struct.pack(">I", self.fragment_offset)[1:] +
struct.pack(">I", self.fragment_length)[1:] +
self.body.to_bytes())
@classmethod
def from_bytes(cls, data):
handshake_type = HandshakeType(struct.unpack(">B", data.read(1))[0])
length = struct.unpack(">I", b'\x00' + data.read(3))[0]
message_seq = struct.unpack(">H", data.read(2))[0]
fragment_offset = struct.unpack(">I", b'\x00' + bytes(data.read(3)))[0]
fragment_length = struct.unpack(">I", b'\x00' + bytes(data.read(3)))[0]
end_position = data.tell() + fragment_length
# TODO(wgtdkp): handle fragmentation
message_class, body = handshake_map[handshake_type], None
if message_class:
body = message_class.from_bytes(data)
else:
print("{} messages are not handled".format(str(handshake_type)))
body = bytes(data.read(fragment_length))
assert data.tell() == end_position
return cls(handshake_type, length, message_seq,
fragment_offset, fragment_length, body)
def __repr__(self):
return "Handshake(type={}, length={})".format(str(self.handshake_type), self.length)
class ProtocolVersion(ConvertibleToBytes, BuildableFromBytes):
def __init__(self, major, minor):
self.major = major
self.minor = minor
def __eq__(self, other):
return (type(self) == type(other) and
self.major == other.major and
self.minor == other.minor)
def to_bytes(self):
return struct.pack(">BB", self.major, self.minor)
@classmethod
def from_bytes(cls, data):
major, minor = struct.unpack(">BB", data.read(2))
return cls(major, minor)
def __repr__(self):
return "ProtocolVersion(major={}, minor={})".format(self.major, self.minor)
class Random(ConvertibleToBytes, BuildableFromBytes):
random_bytes_length = 28
def __init__(self, gmt_unix_time, random_bytes):
self.gmt_unix_time = gmt_unix_time
self.random_bytes = random_bytes
assert len(self.random_bytes) == Random.random_bytes_length
def __eq__(self, other):
return (type(self) == type(other) and
self.gmt_unix_time == other.gmt_unix_time and
self.random_bytes == other.random_bytes)
def to_bytes(self):
return struct.pack(">I", self.gmt_unix_time) + (self.random_bytes)
@classmethod
def from_bytes(cls, data):
gmt_unix_time = struct.unpack(">I", data.read(4))[0]
random_bytes = bytes(data.read(cls.random_bytes_length))
return cls(gmt_unix_time, random_bytes)
class VariableVector(ConvertibleToBytes):
def __init__(self, subrange, ele_cls, elements):
self.subrange = subrange
self.ele_cls = ele_cls
self.elements = elements
assert self.subrange[0] <= len(self.elements) <= self.subrange[1]
def length(self):
return len(self.elements)
def __eq__(self, other):
return (type(self) == type(other) and
self.subrange == other.subrange and
self.ele_cls == other.ele_cls and
self.elements == other.elements)
def to_bytes(self):
data = reduce(lambda ele, acc: acc + ele.to_bytes(), self.elements)
return VariableVector._encode_length(len(data), self.subrange) + data
@classmethod
def from_bytes(cls, ele_cls, subrange, data):
length = cls._decode_length(subrange, data)
end_position = data.tell() + length
elements = []
while data.tell() < end_position:
elements.append(ele_cls.from_bytes(data))
return cls(subrange, ele_cls, elements)
@classmethod
def _decode_length(cls, subrange, data):
length_in_byte = cls._calc_length_in_byte(subrange[1])
return reduce(lambda acc, byte: (acc << 8) | byte, bytearray(data.read(length_in_byte)), 0)
@classmethod
def _encode_length(cls, length, subrange):
length_in_byte = cls._calc_length_in_byte(subrange[1])
ret = bytearray([])
while length_in_byte > 0:
ret += bytes(length_in_byte & 0xff)
length_in_byte == length_in_byte >> 8
return ret
@classmethod
def _calc_length_in_byte(cls, ceiling):
return (ceiling.bit_length() + 7) // 8
class Opaque(ConvertibleToBytes, BuildableFromBytes):
def __init__(self, byte):
self.byte = byte
def __eq__(self, other):
return type(self) == type(other) and self.byte == other.byte
def to_bytes(self):
return struct.pack(">B", self.byte)
@classmethod
def from_bytes(cls, data):
return cls(struct.unpack(">B", data.read(1))[0])
class CipherSuite(ConvertibleToBytes, BuildableFromBytes):
def __init__(self, cipher):
self.cipher = cipher
def __eq__(self, other):
return type(self) == type(other) and self.cipher == other.cipher
def to_bytes(self):
return struct.pack(">BB", self.cipher[0], self.cipher[1])
@classmethod
def from_bytes(cls, data):
return cls(struct.unpack(">BB", data.read(2)))
def __repr__(self):
return "CipherSuite({}, {})".format(self.cipher[0], self.cipher[1])
class CompressionMethod(ConvertibleToBytes, BuildableFromBytes):
NULL = 0
def __init__(self):
pass
def __eq__(self, other):
return type(self) == type(other)
def to_bytes(self):
return struct.pack(">B", CompressionMethod.NULL)
@classmethod
def from_bytes(cls, data):
method = struct.unpack(">B", data.read(1))[0]
assert method == cls.NULL
return cls()
class Extension(ConvertibleToBytes, BuildableFromBytes):
def __init__(self, extension_type, extension_data):
self.extension_type = extension_type
self.extension_data = extension_data
def __eq__(self, other):
return (type(self) == type(other) and
self.extension_type == other.extension_type and
self.extension_data == other.extension_data)
def to_bytes(self):
return (struct.pack(">H", self.extension_type) +
self.extension_data.to_bytes())
@classmethod
def from_bytes(cls, data):
extension_type = struct.unpack(">H", data.read(2))[0]
extension_data = VariableVector.from_bytes(Opaque, (0, 2**16 - 1), data)
return cls(extension_type, extension_data)
class ClientHello(HandshakeMessage):
def __init__(self, client_version, random, session_id,
cookie, cipher_suites, compression_methods, extensions):
self.client_version = client_version
self.random = random
self.session_id = session_id
self.cookie = cookie
self.cipher_suites = cipher_suites
self.compression_methods = compression_methods
self.extensions = extensions
def to_bytes(self):
return (self.client_version.to_bytes() +
self.random.to_bytes() +
self.session_id.to_bytes() +
self.cookie.to_bytes() +
self.cipher_suites.to_bytes() +
self.compression_methods.to_bytes() +
self.extensions.to_bytes())
@classmethod
def from_bytes(cls, data):
client_version = ProtocolVersion.from_bytes(data)
random = Random.from_bytes(data)
session_id = VariableVector.from_bytes(Opaque, (0, 32), data)
cookie = VariableVector.from_bytes(Opaque, (0, 2**8 - 1), data)
cipher_suites = VariableVector.from_bytes(CipherSuite, (2, 2**16 - 1), data)
compression_methods = VariableVector.from_bytes(CompressionMethod, (1, 2**8 - 1), data)
extensions = None
if data.tell() < len(data.getvalue()):
extensions = VariableVector.from_bytes(Extension, (0, 2**16 - 1), data)
return cls(client_version, random, session_id,
cookie, cipher_suites, compression_methods, extensions)
class HelloVerifyRequest(HandshakeMessage):
def __init__(self, server_version, cookie):
self.server_version = server_version
self.cookie = cookie
def to_bytes(self):
return self.server_version.to_bytes() + self.cookie.to_bytes()
@classmethod
def from_bytes(cls, data):
server_version = ProtocolVersion.from_bytes(data)
cookie = VariableVector.from_bytes(Opaque, (0, 2**8 - 1), data)
return cls(server_version, cookie)
class ServerHello(HandshakeMessage):
def __init__(self, server_version, random, session_id,
cipher_suite, compression_method, extensions):
self.server_version = server_version
self.random = random
self.session_id = session_id
self.cipher_suite = cipher_suite
self.compression_method = compression_method
self.extensions = extensions
def to_bytes(self):
return (self.server_version.to_bytes() +
self.random.to_bytes() +
self.session_id.to_bytes() +
self.cipher_suite.to_bytes() +
self.compression_method.to_bytes() +
self.extensions.to_bytes())
@classmethod
def from_bytes(cls, data):
server_version = ProtocolVersion.from_bytes(data)
random = Random.from_bytes(data)
session_id = VariableVector.from_bytes(Opaque, (0, 32), data)
cipher_suite = CipherSuite.from_bytes(data)
compression_method = CompressionMethod.from_bytes(data)
extensions = None
if data.tell() < len(data.getvalue()):
extensions = VariableVector.from_bytes(Extension, (0, 2**16 - 1), data)
return cls(server_version, random, session_id,
cipher_suite, compression_method, extensions)
class ServerHelloDone(HandshakeMessage):
def __init__(self):
pass
def to_bytes(self):
return bytearray([])
@classmethod
def from_bytes(cls, data):
return cls()
class HelloRequest(HandshakeMessage):
def __init__(self):
raise NotImplementedError
class Certificate(HandshakeMessage):
def __init__(self):
raise NotImplementedError
class ServerKeyExchange(HandshakeMessage):
def __init__(self):
raise NotImplementedError
class CertificateRequest(HandshakeMessage):
def __init__(self):
raise NotImplementedError
class CertificateVerify(HandshakeMessage):
def __init__(self):
raise NotImplementedError
class ClientKeyExchange(HandshakeMessage):
def __init__(self):
raise NotImplementedError
class Finished(HandshakeMessage):
def __init__(self, verify_data):
raise NotImplementedError
class AlertMessage(Message):
def __init__(self, level, description):
super(AlertMessage, self).__init__(ContentType.ALERT)
self.level = level
self.description = description
def to_bytes(self):
struct.pack(">BB", self.level, self.description)
@classmethod
def from_bytes(cls, data):
level, description = struct.unpack(">BB", data.read(2))
try:
return cls(AlertLevel(level), AlertDescription(description))
except:
data.read()
# An AlertMessage could be encrypted and we can't parsing it.
return cls(None, None)
def __repr__(self):
return "Alert(level={}, description={})".format(
str(self.level), str(self.description))
class ChangeCipherSpecMessage(Message):
def __init__(self):
super(ChangeCipherSpecMessage, self).__init__(ContentType.CHANGE_CIPHER_SPEC)
def to_bytes(self):
return struct.pack(">B", 1)
@classmethod
def from_bytes(cls, data):
assert struct.unpack(">B", data.read(1))[0] == 1
return cls()
def __repr__(self):
return "ChangeCipherSpec(value=1)"
class ApplicationDataMessage(Message):
def __init__(self, raw):
super(ApplicationDataMessage, self).__init__(ContentType.APPLICATION_DATA)
self.raw = raw
self.body = None
def to_bytes(self):
return self.raw
@classmethod
def from_bytes(cls, data):
# It is safe to read until the end of this byte stream, because
# there is single application data message in a record.
length = len(data.getvalue()) - data.tell()
return cls(bytes(data.read(length)))
def __repr__(self):
if self.body:
return "ApplicationData(body={})".format(self.body)
else:
return "ApplicationData(raw_length={})".format(len(self.raw))
handshake_map = {
HandshakeType.HELLO_REQUEST: None, # HelloRequest
HandshakeType.CLIENT_HELLO: ClientHello,
HandshakeType.SERVER_HELLO: ServerHello,
HandshakeType.HELLO_VERIFY_REQUEST: HelloVerifyRequest,
HandshakeType.CERTIFICATE: None, # Certificate
HandshakeType.SERVER_KEY_EXCHANGE: None, # ServerKeyExchange
HandshakeType.CERTIFICATE_REQUEST: None, # CertificateRequest
HandshakeType.SERVER_HELLO_DONE: ServerHelloDone,
HandshakeType.CERTIFICATE_VERIFY: None, # CertificateVerify
HandshakeType.CLIENT_KEY_EXCHANGE: None, # ClientKeyExchange
HandshakeType.FINISHED: None, # Finished
}
content_map = {
ContentType.CHANGE_CIPHER_SPEC: ChangeCipherSpecMessage,
ContentType.ALERT: AlertMessage,
ContentType.HANDSHAKE: HandshakeMessage,
ContentType.APPLICATION_DATA: ApplicationDataMessage
}
class MessageFactory(object):
last_msg_is_change_cipher_spec = False
def __init__(self):
pass
def parse(self, data, message_info):
messages = []
# Multiple records could be sent in the same UDP datagram
while data.tell() < len(data.getvalue()):
record = Record.from_bytes(data)
if record.version.major != 0xfe or record.version.minor != 0xfd:
raise ValueError("DTLS version error, expect DTLSv1.2")
last_msg_is_change_cipher_spec = type(self).last_msg_is_change_cipher_spec
type(self).last_msg_is_change_cipher_spec = (record.content_type == ContentType.CHANGE_CIPHER_SPEC)
# FINISHED message immediately follows CHANGE_CIPHER_SPEC message
# We skip FINISHED message as it is encrypted
if last_msg_is_change_cipher_spec:
continue
fragment_data = io.BytesIO(record.fragment)
# Multiple handshake messages could be sent in the same record
while fragment_data.tell() < len(fragment_data.getvalue()):
content_class = content_map[record.content_type]
assert content_class
messages.append(content_class.from_bytes(fragment_data))
return messages
+468
View File
@@ -0,0 +1,468 @@
#!/usr/bin/env python
#
# Copyright (c) 2019, The OpenThread Authors.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
from enum import IntEnum
import io
import struct
from network_data import SubTlvsFactory
class TlvType(IntEnum):
EXTENDED_PANID = 2
NETWORK_NAME = 3
STEERING_DATA = 8
COMMISSIONER_UDP_PORT = 15
STATE = 16
JOINER_UDP_PORT = 18
PROVISIONING_URL = 32
VENDOR_NAME = 33
VENDOR_MODEL = 34
VENDOR_SW_VERSION = 35
VENDOR_DATA = 36
VENDOR_STACK_VERSION = 37
DISCOVERY_REQUEST = 128
DISCOVERY_RESPONSE = 129
class MeshCopMessageType(IntEnum):
JOIN_FIN_REQ = 1
JOIN_FIN_RSP = 2
JOIN_ENT_NTF = 3
JOIN_ENT_RSP = 4
def create_mesh_cop_message_type_set():
return [ MeshCopMessageType.JOIN_FIN_REQ,
MeshCopMessageType.JOIN_FIN_RSP,
MeshCopMessageType.JOIN_ENT_NTF,
MeshCopMessageType.JOIN_ENT_RSP ]
class State(object):
def __init__(self, state):
self._state = state
@property
def state(self):
return self._state
def __eq__(self, other):
return self.state == other.state
def __repr__(self):
return "State(state={})".format(self.state)
class StateFactory:
def parse(self, data):
state = ord(data.read(1))
return State(state)
class VendorName(object):
def __init__(self, vendor_name):
self._vendor_name = vendor_name
@property
def vendor_name(self):
return self._vendor_name
def __eq__(self, other):
return self.vendor_name == other.vendor_name
def __repr__(self):
return "VendorName(vendor_name={})".format(self.vendor_name)
class VendorNameFactory:
def parse(self, data):
vendor_name = data.getvalue().decode('utf-8')
return VendorName(vendor_name)
class VendorModel(object):
def __init__(self, vendor_model):
self._vendor_model = vendor_model
@property
def vendor_model(self):
return self._vendor_model
def __eq__(self, other):
return self.vendor_model == other.vendor_model
def __repr__(self):
return "VendorModel(vendor_model={})".format(self.vendor_model)
class VendorModelFactory:
def parse(self, data):
vendor_model = data.getvalue().decode('utf-8')
return VendorModel(vendor_model)
class VendorSWVersion(object):
def __init__(self, vendor_sw_version):
self._vendor_sw_version = vendor_sw_version
@property
def vendor_sw_version(self):
return self._vendor_sw_version
def __eq__(self, other):
return self.vendor_sw_version == other.vendor_sw_version
def __repr__(self):
return "VendorName(vendor_sw_version={})".format(self.vendor_sw_version)
class VendorSWVersionFactory:
def parse(self, data):
vendor_sw_version = data.getvalue()
return VendorSWVersion(vendor_sw_version)
# VendorStackVersion TLV (37)
class VendorStackVersion(object):
def __init__(self, stack_vendor_oui, build, rev, minor, major):
self._stack_vendor_oui = stack_vendor_oui
self._build = build
self._rev = rev
self._minor = minor
self._major = major
return
@property
def stack_vendor_oui(self):
return self._stack_vendor_oui
@property
def build(self):
return self._build
@property
def rev(self):
return self._rev
@property
def minor(self):
return self._minor
@property
def major(self):
return self._major
def __repr__(self):
return "VendorStackVersion(vendor_stack_version={}, build={}, rev={}, minor={}, major={})".format(self.stack_vendor_oui, self.build, self.rev, self.minor, self.major)
class VendorStackVersionFactory:
def parse(self, data):
stack_vendor_oui = struct.unpack(">H", data.read(2))[0]
rest = struct.unpack(">BBBB", data.read(4))
build = rest[1] << 4 | (0xf0 & rest[2])
rev = 0xf & rest[2]
minor = rest[3] & 0xf0
major = rest[3] & 0xf
return VendorStackVersion(stack_vendor_oui, build, rev, minor, major)
class ProvisioningUrl(object):
def __init__(self, url):
self._url = url
@property
def url(self):
return self._url
def __repr__(self):
return "ProvisioningUrl(url={})".format(self.url)
class ProvisioningUrlFactory:
def parse(self, data):
url = data.decode('utf-8')
return ProvisioningUrl(url)
class VendorData(object):
def __init__(self, data):
self._vendor_data = data
@property
def vendor_data(self):
return self._vendor_data
def __repr__(self):
return "Vendor(url={})".format(self.vendor_data)
class VendorDataFactory(object):
def parse(self, data):
return VendorData(data)
class MeshCopCommand(object):
def __init__(self, _type, tlvs):
self._type = _type
self._tlvs = tlvs
@property
def type(self):
return self._type
@property
def tlvs(self):
return self._tlvs
def __repr__(self):
tlvs_str = ", ".join(["{}".format(tlv) for tlv in self.tlvs])
return "MeshCopCommand(type={}, tlvs=[{}])".format(self.type, tlvs_str)
def create_deault_mesh_cop_msg_type_map():
return {
'JOIN_FIN.req': MeshCopMessageType.JOIN_FIN_REQ,
'JOIN_FIN.rsp': MeshCopMessageType.JOIN_FIN_RSP,
'JOIN_ENT.ntf': MeshCopMessageType.JOIN_ENT_NTF,
'JOIN_ENT.rsp': MeshCopMessageType.JOIN_ENT_RSP
}
class MeshCopCommandFactory:
def __init__(self, tlvs_factories):
self._tlvs_factories = tlvs_factories
self._mesh_cop_msg_type_map = create_deault_mesh_cop_msg_type_map()
def _get_length(self, data):
return ord(data.read(1))
def _get_tlv_factory(self, _type):
try:
return self._tlvs_factories[_type]
except KeyError:
raise KeyError("Could not find TLV factory. Unsupported TLV type: {}".format(_type))
def _parse_tlv(self, data):
_type = TlvType(ord(data.read(1)))
length = self._get_length(data)
value = data.read(length)
factory = self._get_tlv_factory(_type)
if factory == None:
return None
return factory.parse(io.BytesIO(value))
def _get_mesh_cop_msg_type(self, msg_type_str):
tp = self._mesh_cop_msg_type_map[msg_type_str]
if tp == None:
raise RuntimeError('Mesh cop message type not found: {}'.format(msg_type_str))
return tp
def parse(self, cmd_type_str, data):
cmd_type = self._get_mesh_cop_msg_type(cmd_type_str)
tlvs = []
while data.tell() < len(data.getvalue()):
tlv = self._parse_tlv(data)
tlvs.append(tlv)
return MeshCopCommand(cmd_type, tlvs)
def create_default_mesh_cop_tlv_factories():
return {
TlvType.STATE: StateFactory(),
TlvType.PROVISIONING_URL: ProvisioningUrlFactory(),
TlvType.VENDOR_NAME: VendorNameFactory(),
TlvType.VENDOR_MODEL: VendorModelFactory(),
TlvType.VENDOR_SW_VERSION: VendorSWVersionFactory(),
TlvType.VENDOR_DATA: VendorDataFactory(),
TlvType.VENDOR_STACK_VERSION: VendorStackVersionFactory()
}
class ThreadDiscoveryTlvsFactory(SubTlvsFactory):
def __init__(self, sub_tlvs_factories):
super(ThreadDiscoveryTlvsFactory, self).__init__(sub_tlvs_factories)
class DiscoveryRequest(object):
def __init__(self, version, joiner_flag):
self._version = version
self._joiner_flag = joiner_flag
@property
def version(self):
return self._version
@property
def joiner_flag(self):
return self._joiner_flag
def __eq__(self, other):
return (type(self) is type(other)
and self.version == other.version
and self.joiner_flag == other.joiner_flag)
def __repr__(self):
return "DiscoveryRequest(version={}, joiner_flag={})".format(
self.version, self.joiner_flag)
class DiscoveryRequestFactory(object):
def parse(self, data, message_info):
data_byte = struct.unpack(">B", data.read(1))[0]
version = (data_byte & 0xf0) >> 4
joiner_flag = (data_byte & 0x08) >> 3
return DiscoveryRequest(version, joiner_flag)
class DiscoveryResponse(object):
def __init__(self, version, native_flag):
self._version = version
self._native_flag = native_flag
@property
def version(self):
return self._version
@property
def native_flag(self):
return self._native_flag
def __eq__(self, other):
return (type(self) is type(other)
and self.version == other.version
and self.native_flag == other.native_flag)
def __repr__(self):
return "DiscoveryResponse(version={}, native_flag={})".format(
self.version, self.native_flag)
class DiscoveryResponseFactory(object):
def parse(self, data, message_info):
data_byte = struct.unpack(">B", data.read(1))[0]
version = (data_byte & 0xf0) >> 4
native_flag = (data_byte & 0x08) >> 3
return DiscoveryResponse(version, native_flag)
class ExtendedPanid(object):
def __init__(self, extended_panid):
self._extended_panid = extended_panid
@property
def extended_panid(self):
return self._extended_panid
def __eq__(self, other):
return (type(self) is type(other)
and self.extended_panid == other.extended_panid)
def __repr__(self):
return "ExtendedPanid(extended_panid={})".format(self.extended_panid)
class ExtendedPanidFactory(object):
def parse(self, data, message_info):
extended_panid = struct.unpack(">Q", data.read(8))[0]
return ExtendedPanid(extended_panid)
class NetworkName(object):
def __init__(self, network_name):
self._network_name = network_name
@property
def network_name(self):
return self._network_name
def __eq__(self, other):
return (type(self) is type(other)
and self.network_name == other.network_name)
def __repr__(self):
return "NetworkName(network_name={})".format(self.network_name)
class NetworkNameFactory(object):
def parse(self, data, message_info):
len = message_info.length
network_name = struct.unpack("{}s".format(10), data.read(len))[0]
return NetworkName(network_name)
class JoinerUdpPort(object):
def __init__(self, udp_port):
self._udp_port = udp_port
@property
def udp_port(self):
return self._udp_port
def __eq__(self, other):
return type(self) is type(other) and self.udp_port == other.udp_port
def __repr__(self):
return "JoinerUdpPort(udp_port={})".format(self.udp_port)
class JoinerUdpPortFactory(object):
def parse(self, data, message_info):
udp_port = struct.unpack(">H", data.read(2))[0]
return JoinerUdpPort(udp_port)
+72 -5
View File
@@ -34,6 +34,7 @@ import sys
import coap
import common
import dtls
import ipv6
import lowpan
import mac802154
@@ -49,6 +50,7 @@ class MessageType(IntEnum):
BEACON = 4
DATA = 5
COMMAND = 6
DTLS = 7
class Message(object):
@@ -61,6 +63,7 @@ class Message(object):
self._coap = None
self._mle = None
self._icmp = None
self._dtls = None
def _extract_udp_datagram(self, udp_datagram):
if isinstance(udp_datagram.payload, mle.MleMessage):
@@ -71,6 +74,11 @@ class Message(object):
self._type = MessageType.COAP
self._coap = udp_datagram.payload
# DTLS message factory returns a list of messages
elif isinstance(udp_datagram.payload, list):
self._type = MessageType.DTLS
self._dtls = udp_datagram.payload
def _extract_upper_layer_protocol(self, upper_layer_protocol):
if isinstance(upper_layer_protocol, ipv6.ICMPv6):
self._type = MessageType.ICMP
@@ -79,6 +87,32 @@ class Message(object):
elif isinstance(upper_layer_protocol, ipv6.UDPDatagram):
self._extract_udp_datagram(upper_layer_protocol)
def try_extract_dtls_messages(self):
"""Extract multiple dtls messages that are sent in a single UDP datagram
"""
if self.type != MessageType.DTLS:
return [self.clone()]
assert isinstance(self.dtls, list)
ret = []
for dtls in self.dtls:
msg = self.clone()
msg._dtls = dtls
ret.append(msg)
return ret
def clone(self):
msg = Message()
msg._type = self.type
msg._channel = self.channel
msg._mac_header = self.mac_header
msg._ipv6_packet = self.ipv6_packet
msg._coap = self.coap
msg._mle = self.mle
msg._icmp = self.icmp
msg._dtls = self.dtls
return msg
@property
def type(self):
return self._type
@@ -141,6 +175,10 @@ class Message(object):
def icmp(self, value):
self._icmp = value
@property
def dtls(self):
return self._dtls
def get_mle_message_tlv(self, tlv_class_type):
if self.type != MessageType.MLE:
raise ValueError("Invalid message type. Expected MLE message.")
@@ -306,19 +344,30 @@ class Message(object):
def isMacAddressTypeLong(self):
return self.mac_header.dest_address.type == common.MacAddressType.LONG
def get_dst_udp_port(self):
assert isinstance(self.ipv6_packet.upper_layer_protocol, ipv6.UDPDatagram)
return self.ipv6_packet.upper_layer_protocol.header.dst_port
def __repr__(self):
if self.type == MessageType.DTLS and self.dtls.content_type == dtls.ContentType.HANDSHAKE:
return "Message(type={})".format(str(self.dtls.handshake_type))
return "Message(type={})".format(MessageType(self.type).name)
class MessagesSet(object):
def __init__(self, messages):
def __init__(self, messages, commissioning_messages=[]):
self._messages = messages
self._commissioning_messages = commissioning_messages
@property
def messages(self):
return self._messages
@property
def commissioning_messages(self):
return self._commissioning_messages
def next_coap_message(self, code, uri_path=None, assert_enabled=True):
message = None
@@ -433,6 +482,21 @@ class MessagesSet(object):
def next_command_message(self):
return self.next_message_of(MessageType.COMMAND)
def next_dtls_message(self, content_type, handshake_type=None):
while self.messages:
msg = self.messages.pop(0)
if msg.type != MessageType.DTLS:
continue
if msg.dtls.content_type != content_type:
continue
if (content_type == dtls.ContentType.HANDSHAKE and
msg.dtls.handshake_type != handshake_type):
continue
return msg
t = handshake_type if content_type == dtls.ContentType.HANDSHAKE else content_type
raise ValueError("Could not find DTLS message of type: {}".format(str(t)))
def contains_icmp_message(self):
for m in self.messages:
if m.type == MessageType.ICMP:
@@ -472,7 +536,10 @@ class MessagesSet(object):
def clone(self):
"""Make a copy of current MessageSet.
"""
return MessagesSet(self.messages[:])
return MessagesSet(self.messages[:], self.commissioning_messages[:])
def __repr__(self):
return str(self.messages)
class MessageFactory:
@@ -506,7 +573,7 @@ class MessageFactory:
message.mac_header = mac_frame.header
if message.mac_header.frame_type != mac802154.MacHeader.FrameType.DATA:
return message
return [message]
message_info = common.MessageInfo()
message_info.source_mac_address = message.mac_header.src_address
@@ -517,11 +584,11 @@ class MessageFactory:
ipv6_packet = self._lowpan_parser.parse(lowpan_payload, message_info)
if ipv6_packet is None:
return message
return [message]
message.ipv6_packet = ipv6_packet
if message.type == MessageType.MLE:
self._add_device_descriptors(message)
return message
return message.try_extract_dtls_messages()
+19 -5
View File
@@ -1009,17 +1009,30 @@ class PendingOperationalDatasetFactory:
return PendingOperationalDataset()
class ThreadDiscovery:
# TODO: Not implemented yet
class ThreadDiscovery(object):
def __init__(self):
print("ThreadDiscovery is not implemented yet.")
def __init__(self, tlvs):
self._tlvs = tlvs
@property
def tlvs(self):
return self._tlvs
def __eq__(self, other):
return self.tlvs == other.tlvs
def __repr__(self):
return "ThreadDiscovery(tlvs={})".format(self.tlvs)
class ThreadDiscoveryFactory:
def __init__(self, thread_discovery_tlvs_factory):
self._tlvs_factory = thread_discovery_tlvs_factory
def parse(self, data, message_info):
return ThreadDiscovery()
tlvs = self._tlvs_factory.parse(data, message_info)
return ThreadDiscovery(tlvs)
class TimeRequest:
# TODO: Not implemented yet
@@ -1044,6 +1057,7 @@ class TimeParameterFactory:
def parse(self, data, message_info):
return TimeParameter()
class MleCommand(object):
def __init__(self, _type, tlvs):
+64 -12
View File
@@ -45,6 +45,8 @@ class otCli:
self.verbose = int(float(os.getenv('VERBOSE', 0)))
self.node_type = os.getenv('NODE_TYPE', 'sim')
self.simulator = simulator
if self.simulator:
self.simulator.add_node(self)
mode = os.environ.get('USE_MTD') is '1' and is_mtd and 'mtd' or 'ftd'
@@ -152,6 +154,56 @@ class otCli:
self.pexpect.expect(pexpect.EOF)
self.pexpect = None
def read_cert_messages_in_commissioning_log(self, timeout=-1):
"""Get the log of the traffic after DTLS handshake.
"""
format_str = br"=+?\[\[THCI\].*?type=%s.*?\].*?=+?[\s\S]+?-{40,}"
join_fin_req = format_str % br"JOIN_FIN\.req"
join_fin_rsp = format_str % br"JOIN_FIN\.rsp"
dummy_format_str = br"\[THCI\].*?type=%s.*?"
join_ent_ntf = dummy_format_str % br"JOIN_ENT\.ntf"
join_ent_rsp = dummy_format_str % br"JOIN_ENT\.rsp"
pattern = (b"(" + join_fin_req + b")|(" + join_fin_rsp + b")|("+ join_ent_ntf + b")|(" + join_ent_rsp + b")")
messages = []
# There are at most 4 cert messages both for joiner and commissioner
for _ in range(0, 4):
try:
self._expect(pattern, timeout=timeout)
log = self.pexpect.match.group(0)
messages.append(self._extract_cert_message(log))
except:
break
return messages
def _extract_cert_message(self, log):
res = re.search(br"direction=\w+", log)
assert res
direction = res.group(0).split(b'=')[1].strip()
res = re.search(br"type=\S+", log)
assert res
type = res.group(0).split(b'=')[1].strip()
payload = bytearray([])
payload_len = 0
if type in [b"JOIN_FIN.req", b"JOIN_FIN.rsp"]:
res = re.search(br"len=\d+", log)
assert res
payload_len = int(res.group(0).split(b'=')[1].strip())
hex_pattern = br"\|(\s([0-9a-fA-F]{2}|\.\.))+?\s+?\|"
while True:
res = re.search(hex_pattern, log)
if not res:
break
data = [int(hex, 16) for hex in res.group(0)[1:-1].split(b' ') if hex and hex != b'..']
payload += bytearray(data)
log = log[res.end()-1:]
assert len(payload) == payload_len
return (direction, type, payload)
def send_command(self, cmd, go=True):
print("%d: %s" % (self.nodeid, cmd))
self.pexpect.send(cmd + '\n')
@@ -164,7 +216,7 @@ class otCli:
self._expect('Commands:')
commands = []
while True:
i = self._expect(['Done', '(\S+)'])
i = self._expect(['Done', r'(\S+)'])
if i != 0:
commands.append(self.pexpect.match.groups()[0])
else:
@@ -280,7 +332,7 @@ class otCli:
def get_channel(self):
self.send_command('channel')
i = self._expect('(\d+)\r?\n')
i = self._expect(r'(\d+)\r?\n')
if i == 0:
channel = int(self.pexpect.match.groups()[0])
self._expect('Done')
@@ -306,7 +358,7 @@ class otCli:
def get_key_sequence_counter(self):
self.send_command('keysequence counter')
i = self._expect('(\d+)\r?\n')
i = self._expect(r'(\d+)\r?\n')
if i == 0:
key_sequence_counter = int(self.pexpect.match.groups()[0])
self._expect('Done')
@@ -330,7 +382,7 @@ class otCli:
def get_network_name(self):
self.send_command('networkname')
while True:
i = self._expect(['Done', '(\S+)'])
i = self._expect(['Done', r'(\S+)'])
if i != 0:
network_name = self.pexpect.match.groups()[0].decode('utf-8')
else:
@@ -357,7 +409,7 @@ class otCli:
def get_partition_id(self):
self.send_command('leaderpartitionid')
i = self._expect('(\d+)\r?\n')
i = self._expect(r'(\d+)\r?\n')
if i == 0:
weight = self.pexpect.match.groups()[0]
self._expect('Done')
@@ -397,7 +449,7 @@ class otCli:
def get_timeout(self):
self.send_command('childtimeout')
i = self._expect('(\d+)\r?\n')
i = self._expect(r'(\d+)\r?\n')
if i == 0:
timeout = self.pexpect.match.groups()[0]
self._expect('Done')
@@ -415,7 +467,7 @@ class otCli:
def get_weight(self):
self.send_command('leaderweight')
i = self._expect('(\d+)\r?\n')
i = self._expect(r'(\d+)\r?\n')
if i == 0:
weight = self.pexpect.match.groups()[0]
self._expect('Done')
@@ -436,7 +488,7 @@ class otCli:
self.send_command('ipaddr')
while True:
i = self._expect(['(\S+(:\S*)+)\r?\n', 'Done'])
i = self._expect([r'(\S+(:\S*)+)\r?\n', 'Done'])
if i == 0:
addrs.append(self.pexpect.match.groups()[0].decode("utf-8"))
elif i == 1:
@@ -464,7 +516,7 @@ class otCli:
self.send_command('eidcache')
while True:
i = self._expect(['([a-fA-F0-9\:]+) ([a-fA-F0-9]+)\r?\n', 'Done'])
i = self._expect([r'([a-fA-F0-9\:]+) ([a-fA-F0-9]+)\r?\n', 'Done'])
if i == 0:
eid = self.pexpect.match.groups()[0].decode("utf-8")
rloc = self.pexpect.match.groups()[1].decode("utf-8")
@@ -553,7 +605,7 @@ class otCli:
def get_context_reuse_delay(self):
self.send_command('contextreusedelay')
i = self._expect('(\d+)\r?\n')
i = self._expect(r'(\d+)\r?\n')
if i == 0:
timeout = self.pexpect.match.groups()[0]
self._expect('Done')
@@ -617,7 +669,7 @@ class otCli:
results = []
while True:
i = self._expect(['\|\s(\S+)\s+\|\s(\S+)\s+\|\s([0-9a-fA-F]{4})\s\|\s([0-9a-fA-F]{16})\s\|\s(\d+)\r?\n',
i = self._expect([r'\|\s(\S+)\s+\|\s(\S+)\s+\|\s([0-9a-fA-F]{4})\s\|\s([0-9a-fA-F]{16})\s\|\s(\d+)\r?\n',
'Done'])
if i == 0:
results.append(self.pexpect.match.groups())
@@ -640,7 +692,7 @@ class otCli:
try:
responders = {}
while len(responders) < num_responses:
i = self._expect(['from (\S+):'])
i = self._expect([r'from (\S+):'])
if i == 0:
responders[self.pexpect.match.groups()[0]] = 1
self._expect('\n')
+1
View File
@@ -76,6 +76,7 @@ class PcapCodec(object):
timestamp = self._get_timestamp()
pkt = self.encode_frame(frame, *timestamp)
self._pcap_file.write(pkt)
self._pcap_file.flush()
def __del__(self):
self._pcap_file.close()
+53 -8
View File
@@ -39,6 +39,8 @@ import time
import io
import config
import dtls
import mesh_cop
import message
import pcap
@@ -46,9 +48,47 @@ def dbg_print(*args):
if False:
print(args)
class RealTime:
class BaseSimulator(object):
def __init__(self):
self._nodes = {}
self.commissioning_messages = {}
self._payload_parse_factory = mesh_cop.MeshCopCommandFactory(mesh_cop.create_default_mesh_cop_tlv_factories())
self._mesh_cop_msg_set = mesh_cop.create_mesh_cop_message_type_set()
def __del__(self):
self._nodes = None
def add_node(self, node):
self._nodes[node.nodeid] = node
self.commissioning_messages[node.nodeid] = []
def set_lowpan_context(self, cid, prefix):
raise NotImplementedError
def get_messages_sent_by(self, nodeid):
raise NotImplementedError
def go(self, duration, nodeid=None):
raise NotImplementedError
def stop(self):
raise NotImplementedError
def read_cert_messages_in_commissioning_log(self, nodeids):
for nodeid in nodeids:
node = self._nodes[nodeid]
if not node:
continue
for direction, type, payload in node.read_cert_messages_in_commissioning_log():
if direction == b'send':
msg = self._payload_parse_factory.parse(type.decode("utf-8"), io.BytesIO(payload))
self.commissioning_messages[nodeid].append(msg)
class RealTime(BaseSimulator):
def __init__(self):
super(RealTime, self).__init__()
self._sniffer = config.create_default_thread_sniffer()
self._sniffer.start()
@@ -56,7 +96,10 @@ class RealTime:
self._sniffer.set_lowpan_context(cid, prefix)
def get_messages_sent_by(self, nodeid):
return self._sniffer.get_messages_sent_by(nodeid)
messages = self._sniffer.get_messages_sent_by(nodeid).messages
ret = message.MessagesSet(messages, self.commissioning_messages[nodeid])
self.commissioning_messages[nodeid] = []
return ret
def go(self, duration, nodeid=None):
time.sleep(duration)
@@ -64,7 +107,7 @@ class RealTime:
def stop(self):
pass
class VirtualTime:
class VirtualTime(BaseSimulator):
OT_SIM_EVENT_ALARM_FIRED = 0
OT_SIM_EVENT_RADIO_RECEIVED = 1
@@ -91,6 +134,7 @@ class VirtualTime:
NCP_SIM = os.getenv('NODE_TYPE', 'sim') == 'ncp-sim'
def __init__(self):
super(VirtualTime, self).__init__()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
ip = '127.0.0.1'
@@ -126,14 +170,13 @@ class VirtualTime:
# Ignore any exceptions
try:
msg = self._message_factory.create(io.BytesIO(message))
if msg is not None:
self.devices[addr]['msgs'].append(msg)
messages = self._message_factory.create(io.BytesIO(message))
self.devices[addr]['msgs'] += messages
except Exception as e:
# Just print the exception to the console
print("EXCEPTION: %s" % e)
traceback.print_exc()
def set_lowpan_context(self, cid, prefix):
self._message_factory.set_lowpan_context(cid, prefix)
@@ -154,7 +197,9 @@ class VirtualTime:
messages = self.devices[addr]['msgs']
self.devices[addr]['msgs'] = []
return message.MessagesSet(messages)
ret = message.MessagesSet(messages, self.commissioning_messages[nodeid])
self.commissioning_messages[nodeid] = []
return ret
def _is_radio(self, addr):
return addr[1] < self.BASE_PORT * 2
+5 -5
View File
@@ -33,6 +33,7 @@ import logging
import os
import pcap
import threading
import traceback
try:
import Queue
@@ -85,16 +86,15 @@ class Sniffer:
# Ignore any exceptions
try:
msg = self._message_factory.create(io.BytesIO(data))
if msg is not None:
self.logger.debug("Received message: {}".format(msg))
messages = self._message_factory.create(io.BytesIO(data))
self.logger.debug("Received messages: {}".format(messages))
for msg in messages:
self._buckets[nodeid].put(msg)
except Exception as e:
# Just print the exception to the console
print("EXCEPTION: %s" % e)
pass
traceback.print_exc()
self.logger.debug("Sniffer stopped.")