#!/usr/bin/env python # # Copyright (c) 2017-2018, The OpenThread Authors. # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # 1. Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # 2. Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # 3. Neither the name of the copyright holder nor the # names of its contributors may be used to endorse or promote products # derived from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. 
#

import binascii
import sys

import ipv6
import network_data
import network_layer
import common
import config
import mesh_cop
import mle

from collections import Counter
from enum import IntEnum

from network_data import Prefix, BorderRouter, LowpanId


class CheckType(IntEnum):
    """How a check_* helper should treat a TLV (or a TLV ID)."""
    CONTAIN = 0      # must be present
    NOT_CONTAIN = 1  # must be absent
    OPTIONAL = 2     # may or may not be present


class NetworkDataCheckType:
    # NOTE(review): not referenced by the helpers in this file; presumably
    # used by individual test scripts -- confirm before removing.
    PREFIX_CNT = 1
    PREFIX_CONTENT = 2


def _normalize_ip6_address(address):
    """Convert an address argument into a form handed to ipv6.ip_address().

    A bytearray is converted to bytes and, on Python 2 only, a native str is
    decoded from UTF-8 to unicode. Any other value is returned unchanged.
    """
    # NOTE(review): presumably ipv6.ip_address() distinguishes packed bytes
    # from textual addresses by type, hence this normalisation -- confirm.
    if isinstance(address, bytearray):
        return bytes(address)
    if isinstance(address, str) and sys.version_info[0] == 2:
        return address.decode("utf-8")
    return address


def check_address_query(command_msg, source_node, destination_address):
    """Verify source_node sent a properly formatted Address Query Request
    message to the destination_address.

    Asserts that the CoAP payload carries a Target EID TLV, that the IPv6
    source address is source_node's RLOC, and that the IPv6 destination
    address equals destination_address.
    """
    command_msg.assertCoapMessageContainsTlv(network_layer.TargetEid)

    source_rloc = source_node.get_ip6_address(config.ADDRESS_TYPE.RLOC)
    assert ipv6.ip_address(source_rloc) == command_msg.ipv6_packet.ipv6_header.source_address, \
        "Error: The IPv6 source address is not the RLOC of the originator. The source node's rloc is: " \
        + str(ipv6.ip_address(source_rloc)) + ", but the source_address in command msg is: " \
        + str(command_msg.ipv6_packet.ipv6_header.source_address)

    destination_address = _normalize_ip6_address(destination_address)
    assert ipv6.ip_address(destination_address) == command_msg.ipv6_packet.ipv6_header.destination_address, \
        "Error: The IPv6 destination address is not expected."


def check_address_notification(command_msg, source_node, destination_node):
    """Verify source_node sent a properly formatted Address Notification
    command message to destination_node.

    Asserts the CoAP URI path is /a/an, that the Target EID, RLOC16 and ML-EID
    TLVs are present, and that the IPv6 source/destination addresses are the
    RLOCs of source_node and destination_node respectively.
    """
    command_msg.assertCoapMessageRequestUriPath('/a/an')
    command_msg.assertCoapMessageContainsTlv(network_layer.TargetEid)
    command_msg.assertCoapMessageContainsTlv(network_layer.Rloc16)
    command_msg.assertCoapMessageContainsTlv(network_layer.MlEid)

    source_rloc = source_node.get_ip6_address(config.ADDRESS_TYPE.RLOC)
    assert ipv6.ip_address(source_rloc) == command_msg.ipv6_packet.ipv6_header.source_address, \
        "Error: The IPv6 source address is not the RLOC of the originator."

    destination_rloc = destination_node.get_ip6_address(config.ADDRESS_TYPE.RLOC)
    assert ipv6.ip_address(destination_rloc) == command_msg.ipv6_packet.ipv6_header.destination_address, \
        "Error: The IPv6 destination address is not the RLOC of the destination."


def check_address_error_notification(command_msg, source_node, destination_address):
    """Verify source_node sent a properly formatted Address Error Notification
    command message to destination_address.

    Asserts the CoAP URI path is /a/ae, that the Target EID and ML-EID TLVs
    are present, that the IPv6 source address is source_node's RLOC, and that
    the IPv6 destination address equals destination_address.
    """
    command_msg.assertCoapMessageRequestUriPath('/a/ae')
    command_msg.assertCoapMessageContainsTlv(network_layer.TargetEid)
    command_msg.assertCoapMessageContainsTlv(network_layer.MlEid)

    source_rloc = source_node.get_ip6_address(config.ADDRESS_TYPE.RLOC)
    assert ipv6.ip_address(source_rloc) == command_msg.ipv6_packet.ipv6_header.source_address, \
        "Error: The IPv6 source address is not the RLOC of the originator. The source node's rloc is: " \
        + str(ipv6.ip_address(source_rloc)) + ", but the source_address in command msg is: " \
        + str(command_msg.ipv6_packet.ipv6_header.source_address)

    destination_address = _normalize_ip6_address(destination_address)
    assert ipv6.ip_address(destination_address) == command_msg.ipv6_packet.ipv6_header.destination_address, \
        "Error: The IPv6 destination address is not expected. The destination node's rloc is: " \
        + str(ipv6.ip_address(destination_address)) + ", but the destination_address in command msg is: " \
        + str(command_msg.ipv6_packet.ipv6_header.destination_address)


def check_address_solicit(command_msg, was_router):
    """Verify a properly formatted Address Solicit request (CoAP URI /a/as).

    The MAC Extended Address and Status TLVs are always required. The RLOC16
    TLV must be present when was_router is truthy and absent otherwise.
    """
    command_msg.assertCoapMessageRequestUriPath('/a/as')
    command_msg.assertCoapMessageContainsTlv(network_layer.MacExtendedAddress)
    command_msg.assertCoapMessageContainsTlv(network_layer.Status)
    if was_router:
        command_msg.assertCoapMessageContainsTlv(network_layer.Rloc16)
    else:
        # RLOC16 is a CoAP payload TLV here, so the absence check must inspect
        # the CoAP payload (the MLE variant looks at the wrong message layer).
        command_msg.assertCoapMessageDoesNotContainTlv(network_layer.Rloc16)


def check_address_release(command_msg, destination_node):
    """Verify the message is a properly formatted Address Release (CoAP URI
    /a/ar) carrying RLOC16 and MAC Extended Address TLVs, sent to the RLOC of
    destination_node.
    """
    command_msg.assertCoapMessageRequestUriPath('/a/ar')
    command_msg.assertCoapMessageContainsTlv(network_layer.Rloc16)
    command_msg.assertCoapMessageContainsTlv(network_layer.MacExtendedAddress)

    destination_rloc = destination_node.get_ip6_address(config.ADDRESS_TYPE.RLOC)
    assert ipv6.ip_address(destination_rloc) == command_msg.ipv6_packet.ipv6_header.destination_address, \
        "Error: The destination is not RLOC address"


def check_tlv_request_tlv(command_msg, check_type, tlv_id):
    """Verify whether the MLE TLV Request TLV lists the TLV ID tlv_id.

    check_type:
      CheckType.CONTAIN     -- the TLV Request TLV must exist and list tlv_id.
      CheckType.NOT_CONTAIN -- if a TLV Request TLV exists it must not list tlv_id.
      CheckType.OPTIONAL    -- nothing is asserted; the outcome is printed.

    Raises ValueError for any other check_type.
    """
    tlv_request_tlv = command_msg.get_mle_message_tlv(mle.TlvRequest)

    if check_type == CheckType.CONTAIN:
        assert tlv_request_tlv is not None, "Error: The msg doesn't contain TLV Request TLV"
        assert any(tlv_id == tlv for tlv in tlv_request_tlv.tlvs), \
            "Error: The msg doesn't contain TLV Request TLV ID: {}".format(tlv_id)
    elif check_type == CheckType.NOT_CONTAIN:
        if tlv_request_tlv is not None:
            assert not any(tlv_id == tlv for tlv in tlv_request_tlv.tlvs), \
                "Error: The msg contains TLV Request TLV ID: {}".format(tlv_id)
    elif check_type == CheckType.OPTIONAL:
        if tlv_request_tlv is None:
            print("The msg doesn't contain TLV Request TLV")
        elif any(tlv_id == tlv for tlv in tlv_request_tlv.tlvs):
            print("TLV Request TLV contains TLV ID: {}".format(tlv_id))
        else:
            print("TLV Request TLV doesn't contain TLV ID: {}".format(tlv_id))
    else:
        raise ValueError("Invalid check type")


def check_link_request(command_msg, source_address=CheckType.OPTIONAL, leader_data=CheckType.OPTIONAL,
                       tlv_request_address16=CheckType.OPTIONAL, tlv_request_route64=CheckType.OPTIONAL,
                       tlv_request_link_margin=CheckType.OPTIONAL):
    """Verify a properly formatted Link Request command message.

    Challenge and Version TLVs are always required. Source Address and Leader
    Data are checked per their CheckType; the tlv_request_* arguments control
    whether the TLV Request TLV lists Address16, Route64 and Link Margin.
    """
    command_msg.assertMleMessageContainsTlv(mle.Challenge)
    command_msg.assertMleMessageContainsTlv(mle.Version)

    check_mle_optional_tlv(command_msg, source_address, mle.SourceAddress)
    check_mle_optional_tlv(command_msg, leader_data, mle.LeaderData)

    check_tlv_request_tlv(command_msg, tlv_request_address16, mle.TlvType.ADDRESS16)
    check_tlv_request_tlv(command_msg, tlv_request_route64, mle.TlvType.ROUTE64)
    check_tlv_request_tlv(command_msg, tlv_request_link_margin, mle.TlvType.LINK_MARGIN)


def check_link_accept(command_msg, destination_node,
                      leader_data=CheckType.OPTIONAL, link_margin=CheckType.OPTIONAL,
                      mle_frame_counter=CheckType.OPTIONAL, challenge=CheckType.OPTIONAL,
                      address16=CheckType.OPTIONAL, route64=CheckType.OPTIONAL,
                      tlv_request_link_margin=CheckType.OPTIONAL):
    """Verify a properly formatted Link Accept command message.

    Link-layer Frame Counter, Source Address, Response and Version TLVs are
    always required; the remaining TLVs are checked per their CheckType. The
    IPv6 destination must be destination_node's link-local address.
    """
    command_msg.assertMleMessageContainsTlv(mle.LinkLayerFrameCounter)
    command_msg.assertMleMessageContainsTlv(mle.SourceAddress)
    command_msg.assertMleMessageContainsTlv(mle.Response)
    command_msg.assertMleMessageContainsTlv(mle.Version)

    check_mle_optional_tlv(command_msg, leader_data, mle.LeaderData)
    check_mle_optional_tlv(command_msg, link_margin, mle.LinkMargin)
    check_mle_optional_tlv(command_msg, mle_frame_counter, mle.MleFrameCounter)
    check_mle_optional_tlv(command_msg, challenge, mle.Challenge)
    check_mle_optional_tlv(command_msg, address16, mle.Address16)
    check_mle_optional_tlv(command_msg, route64, mle.Route64)

    check_tlv_request_tlv(command_msg, tlv_request_link_margin, mle.TlvType.LINK_MARGIN)

    destination_link_local = destination_node.get_ip6_address(config.ADDRESS_TYPE.LINK_LOCAL)
    assert ipv6.ip_address(destination_link_local) == command_msg.ipv6_packet.ipv6_header.destination_address, \
        "Error: The destination is unexpected"


def check_icmp_path(sniffer, path, nodes, icmp_type=ipv6.ICMP_ECHO_REQUEST):
    """Verify an ICMP message is forwarded along path.

    For every hop except the last, the icmp_type message sent by path[i] must
    be MAC-addressed to the RLOC16 of nodes[path[i + 1]].

    Returns True once the last node of path is reached, or False when path is
    empty.
    """
    last_index = len(path) - 1
    for i, node_id in enumerate(path):
        node_msg = sniffer.get_messages_sent_by(node_id)
        node_icmp_msg = node_msg.get_icmp_message(icmp_type)
        if i == last_index:
            return True
        next_node_rloc16 = nodes[path[i + 1]].get_addr16()
        assert next_node_rloc16 == node_icmp_msg.mac_header.dest_address.rloc, "Error: The path is unexpected."
    return False


def check_id_set(command_msg, router_id):
    """Check the command_msg's Route64 TLV to verify router_id is an active router.

    Returns 1 when bit (63 - router_id) of the Router ID Mask is set, else 0.
    """
    tlv = command_msg.assertMleMessageContainsTlv(mle.Route64)
    return (tlv.router_id_mask >> (63 - router_id)) & 1


def get_routing_cost(command_msg, router_id):
    """Return the route cost to router_id taken from command_msg's Route64 TLV.

    The Router ID Mask is read MSB-first, so router_id maps to bit
    (63 - router_id), the same convention check_id_set() uses. Entries of
    link_quality_and_route_data are taken to correspond, in order, to the set
    bits of the mask, so the entry index is the number of set bits that
    precede router_id.
    """
    tlv = command_msg.assertMleMessageContainsTlv(mle.Route64)

    # Zero-pad to 64 characters so that string index i is router ID i. bin()
    # drops leading zeros, and subtracting the missing prefix length produced a
    # negative index for small router IDs, silently reading the wrong bit.
    router_id_mask_str = format(tlv.router_id_mask, '064b')
    routing_entry_pos = router_id_mask_str[:router_id].count('1')

    assert router_id_mask_str[router_id] == '1', "Error: The router isn't in the topology. \n" \
        + "route64 tlv is: %s. \nrouter_id is: %s. \nrouting_entry_pos is: %s. \nrouter_id_mask_str is: %s." \
        % (tlv, router_id, routing_entry_pos, router_id_mask_str)

    return tlv.link_quality_and_route_data[routing_entry_pos].route


def check_mle_optional_tlv(command_msg, type, tlv):
    """Assert the MLE TLV class tlv is present, absent, or optional according to
    type (a CheckType).

    Raises ValueError for any other value of type.
    """
    # The parameter name shadows the builtin `type`; it is kept so existing
    # keyword-argument callers keep working.
    if type == CheckType.CONTAIN:
        command_msg.assertMleMessageContainsTlv(tlv)
    elif type == CheckType.NOT_CONTAIN:
        command_msg.assertMleMessageDoesNotContainTlv(tlv)
    elif type == CheckType.OPTIONAL:
        command_msg.assertMleMessageContainsOptionalTlv(tlv)
    else:
        raise ValueError("Invalid check type")


def check_mle_advertisement(command_msg):
    """Verify an MLE Advertisement: hop limit 255, sent to the link-local
    all-nodes address, carrying Source Address, Leader Data and Route64 TLVs.
    """
    command_msg.assertSentWithHopLimit(255)
    command_msg.assertSentToDestinationAddress(config.LINK_LOCAL_ALL_NODES_ADDRESS)
    command_msg.assertMleMessageContainsTlv(mle.SourceAddress)
    command_msg.assertMleMessageContainsTlv(mle.LeaderData)
    command_msg.assertMleMessageContainsTlv(mle.Route64)


def check_parent_request(command_msg, is_first_request):
    """Verify a properly formatted Parent Request command message.

    The Scan Mask R bit must always be set. The E bit must be clear on the
    first request (is_first_request truthy) and set otherwise.

    Raises ValueError when the key id mode is not 0x02 or the Scan Mask bits
    are wrong.
    """
    if command_msg.mle.aux_sec_hdr.key_id_mode != 0x2:
        raise ValueError("The Key Identifier Mode of the Security Control Field SHALL be set to 0x02")

    command_msg.assertSentWithHopLimit(255)
    command_msg.assertSentToDestinationAddress(config.LINK_LOCAL_ALL_ROUTERS_ADDRESS)
    command_msg.assertMleMessageContainsTlv(mle.Mode)
    command_msg.assertMleMessageContainsTlv(mle.Challenge)
    command_msg.assertMleMessageContainsTlv(mle.Version)

    scan_mask = command_msg.assertMleMessageContainsTlv(mle.ScanMask)
    if not scan_mask.router:
        raise ValueError("Parent request without R bit set")

    if is_first_request:
        if scan_mask.end_device:
            raise ValueError("First parent request with E bit set")
    elif not scan_mask.end_device:
        raise ValueError("Second parent request without E bit set")


def check_parent_response(command_msg, mle_frame_counter=CheckType.OPTIONAL):
    """Verify a properly formatted Parent Response command message.

    All listed TLVs are required; the MLE Frame Counter TLV is checked per
    mle_frame_counter.
    """
    command_msg.assertMleMessageContainsTlv(mle.Challenge)
    command_msg.assertMleMessageContainsTlv(mle.Connectivity)
    command_msg.assertMleMessageContainsTlv(mle.LeaderData)
    command_msg.assertMleMessageContainsTlv(mle.LinkLayerFrameCounter)
    command_msg.assertMleMessageContainsTlv(mle.LinkMargin)
    command_msg.assertMleMessageContainsTlv(mle.Response)
    command_msg.assertMleMessageContainsTlv(mle.SourceAddress)
    command_msg.assertMleMessageContainsTlv(mle.Version)

    check_mle_optional_tlv(command_msg, mle_frame_counter, mle.MleFrameCounter)


def check_child_id_request(command_msg, tlv_request=CheckType.OPTIONAL,
                           mle_frame_counter=CheckType.OPTIONAL, address_registration=CheckType.OPTIONAL,
                           active_timestamp=CheckType.OPTIONAL, pending_timestamp=CheckType.OPTIONAL,
                           route64=CheckType.OPTIONAL):
    """Verify a properly formatted Child Id Request command message.

    Besides the always-required TLVs, the TLV Request TLV must list both
    Address16 and Network Data regardless of the tlv_request argument.

    Raises ValueError when the key id mode is not 0x02.
    """
    if command_msg.mle.aux_sec_hdr.key_id_mode != 0x2:
        raise ValueError("The Key Identifier Mode of the Security Control Field SHALL be set to 0x02")

    command_msg.assertMleMessageContainsTlv(mle.LinkLayerFrameCounter)
    command_msg.assertMleMessageContainsTlv(mle.Mode)
    command_msg.assertMleMessageContainsTlv(mle.Response)
    command_msg.assertMleMessageContainsTlv(mle.Timeout)
    command_msg.assertMleMessageContainsTlv(mle.Version)

    check_mle_optional_tlv(command_msg, tlv_request, mle.TlvRequest)
    check_mle_optional_tlv(command_msg, mle_frame_counter, mle.MleFrameCounter)
    check_mle_optional_tlv(command_msg, address_registration, mle.AddressRegistration)
    check_mle_optional_tlv(command_msg, active_timestamp, mle.ActiveTimestamp)
    check_mle_optional_tlv(command_msg, pending_timestamp, mle.PendingTimestamp)
    check_mle_optional_tlv(command_msg, route64, mle.Route64)

    check_tlv_request_tlv(command_msg, CheckType.CONTAIN, mle.TlvType.ADDRESS16)
    check_tlv_request_tlv(command_msg, CheckType.CONTAIN, mle.TlvType.NETWORK_DATA)


def check_child_id_response(command_msg, route64=CheckType.OPTIONAL, network_data=CheckType.OPTIONAL,
                            address_registration=CheckType.OPTIONAL, active_timestamp=CheckType.OPTIONAL,
                            pending_timestamp=CheckType.OPTIONAL, active_operational_dataset=CheckType.OPTIONAL,
                            pending_operational_dataset=CheckType.OPTIONAL, network_data_check=None):
    """Verify a properly formatted Child Id Response command message.

    When network_data_check is given, the Network Data TLV must be present
    and is passed to network_data_check.check().
    """
    command_msg.assertMleMessageContainsTlv(mle.SourceAddress)
    command_msg.assertMleMessageContainsTlv(mle.LeaderData)
    command_msg.assertMleMessageContainsTlv(mle.Address16)

    check_mle_optional_tlv(command_msg, route64, mle.Route64)
    check_mle_optional_tlv(command_msg, network_data, mle.NetworkData)
    check_mle_optional_tlv(command_msg, address_registration, mle.AddressRegistration)
    check_mle_optional_tlv(command_msg, active_timestamp, mle.ActiveTimestamp)
    check_mle_optional_tlv(command_msg, pending_timestamp, mle.PendingTimestamp)
    check_mle_optional_tlv(command_msg, active_operational_dataset, mle.ActiveOperationalDataset)
    check_mle_optional_tlv(command_msg, pending_operational_dataset, mle.PendingOperationalDataset)

    if network_data_check is not None:
        network_data_tlv = command_msg.assertMleMessageContainsTlv(mle.NetworkData)
        network_data_check.check(network_data_tlv)


def check_prefix(prefix):
    """Verify a prefix contains a 6LoWPAN ID sub-TLV and a Border Router sub-TLV."""
    assert contains_tlv(prefix.sub_tlvs, network_data.BorderRouter), 'Prefix doesn\'t contain a border router sub-TLV!'
    assert contains_tlv(prefix.sub_tlvs, network_data.LowpanId), 'Prefix doesn\'t contain a LowpanId sub-TLV!'
def check_child_update_request_from_child(command_msg, source_address=CheckType.OPTIONAL,
                                          leader_data=CheckType.OPTIONAL, challenge=CheckType.OPTIONAL,
                                          time_out=CheckType.OPTIONAL, address_registration=CheckType.OPTIONAL,
                                          tlv_request_tlv=CheckType.OPTIONAL, active_timestamp=CheckType.OPTIONAL,
                                          CIDs=()):
    """Verify a properly formatted Child Update Request sent by a child.

    The Mode TLV is always required; the other TLVs are checked per their
    CheckType argument. When address_registration is CheckType.CONTAIN and
    CIDs is non-empty, the Address Registration TLV must also hold a
    compressed entry for every context ID in CIDs.
    """
    command_msg.assertMleMessageContainsTlv(mle.Mode)

    check_mle_optional_tlv(command_msg, source_address, mle.SourceAddress)
    check_mle_optional_tlv(command_msg, leader_data, mle.LeaderData)
    check_mle_optional_tlv(command_msg, challenge, mle.Challenge)
    check_mle_optional_tlv(command_msg, time_out, mle.Timeout)
    check_mle_optional_tlv(command_msg, address_registration, mle.AddressRegistration)
    check_mle_optional_tlv(command_msg, tlv_request_tlv, mle.TlvRequest)
    check_mle_optional_tlv(command_msg, active_timestamp, mle.ActiveTimestamp)

    # CIDs defaults to an immutable tuple to avoid a shared mutable default.
    if address_registration == CheckType.CONTAIN and CIDs:
        _check_address_registration(command_msg, CIDs)


def check_coap_optional_tlv(coap_msg, type, tlv):
    """Assert the CoAP TLV class tlv is present, absent, or optional according
    to type (a CheckType).

    Raises ValueError for any other value of type.
    """
    # The parameter name shadows the builtin `type`; kept for keyword callers.
    if type == CheckType.CONTAIN:
        coap_msg.assertCoapMessageContainsTlv(tlv)
    elif type == CheckType.NOT_CONTAIN:
        coap_msg.assertCoapMessageDoesNotContainTlv(tlv)
    elif type == CheckType.OPTIONAL:
        coap_msg.assertCoapMessageContainsOptionalTlv(tlv)
    else:
        raise ValueError("Invalid check type")


def check_router_id_cached(node, router_id, cached=True):
    """Verify whether node has (cached=True) or has not (cached=False) any EID
    cache entry pointing at router_id.

    node.get_eidcaches() yields (_, rloc) pairs where rloc is a hex string; the
    router ID of an entry is that value shifted right by 10 bits.
    """
    eidcaches = node.get_eidcaches()
    has_entry = any(router_id == (int(rloc, 16) >> 10) for (_, rloc) in eidcaches)
    if cached:
        assert has_entry, 'No EID cache entry found for router ID {}'.format(router_id)
    else:
        assert not has_entry, 'Unexpected EID cache entry found for router ID {}'.format(router_id)


def contains_tlv(sub_tlvs, tlv_type):
    """Return True if any element of sub_tlvs is an instance of tlv_type."""
    return any(isinstance(sub_tlv, tlv_type) for sub_tlv in sub_tlvs)


def contains_tlvs(sub_tlvs, tlv_types):
    """Return True if every type in tlv_types has an instance in sub_tlvs.

    An empty tlv_types yields True.
    """
    return all(contains_tlv(sub_tlvs, tlv_type) for tlv_type in tlv_types)


def check_secure_mle_key_id_mode(command_msg, key_id_mode):
    """Verify the MLE message is secured and uses the given key id mode."""
    assert isinstance(command_msg.mle, mle.MleMessageSecured), 'MLE message is not secured'
    assert command_msg.mle.aux_sec_hdr.key_id_mode == key_id_mode, \
        'Unexpected key id mode: {} (expected {})'.format(command_msg.mle.aux_sec_hdr.key_id_mode, key_id_mode)


def check_data_response(command_msg, network_data_check=None, active_timestamp=CheckType.OPTIONAL):
    """Verify a properly formatted Data Response command message.

    Requires key id mode 0x02 plus Source Address and Leader Data TLVs. When
    network_data_check is given, the Network Data TLV must be present and is
    passed to network_data_check.check().
    """
    check_secure_mle_key_id_mode(command_msg, 0x02)

    command_msg.assertMleMessageContainsTlv(mle.SourceAddress)
    command_msg.assertMleMessageContainsTlv(mle.LeaderData)
    check_mle_optional_tlv(command_msg, active_timestamp, mle.ActiveTimestamp)

    if network_data_check is not None:
        network_data_tlv = command_msg.assertMleMessageContainsTlv(mle.NetworkData)
        network_data_check.check(network_data_tlv)


def check_child_update_request_from_parent(command_msg, leader_data=CheckType.OPTIONAL,
                                           network_data=CheckType.OPTIONAL, challenge=CheckType.OPTIONAL,
                                           tlv_request=CheckType.OPTIONAL, active_timestamp=CheckType.OPTIONAL):
    """Verify a properly formatted Child Update Request (from parent) command message.

    Requires key id mode 0x02 and a Source Address TLV; the other TLVs are
    checked per their CheckType argument.
    """
    check_secure_mle_key_id_mode(command_msg, 0x02)

    command_msg.assertMleMessageContainsTlv(mle.SourceAddress)
    check_mle_optional_tlv(command_msg, leader_data, mle.LeaderData)
    check_mle_optional_tlv(command_msg, network_data, mle.NetworkData)
    check_mle_optional_tlv(command_msg, challenge, mle.Challenge)
    check_mle_optional_tlv(command_msg, tlv_request, mle.TlvRequest)
    check_mle_optional_tlv(command_msg, active_timestamp, mle.ActiveTimestamp)


def check_child_update_response(command_msg, timeout=CheckType.OPTIONAL, address_registration=CheckType.OPTIONAL,
                                address16=CheckType.OPTIONAL, leader_data=CheckType.OPTIONAL,
                                network_data=CheckType.OPTIONAL, response=CheckType.OPTIONAL,
                                link_layer_frame_counter=CheckType.OPTIONAL, mle_frame_counter=CheckType.OPTIONAL,
                                CIDs=()):
    """Verify a properly formatted Child Update Response from parent.

    Requires key id mode 0x02 plus Source Address and Mode TLVs; the other
    TLVs are checked per their CheckType argument. When address_registration
    is CheckType.CONTAIN and CIDs is non-empty, every context ID in CIDs must
    appear as a compressed Address Registration entry.
    """
    check_secure_mle_key_id_mode(command_msg, 0x02)

    command_msg.assertMleMessageContainsTlv(mle.SourceAddress)
    command_msg.assertMleMessageContainsTlv(mle.Mode)
    check_mle_optional_tlv(command_msg, timeout, mle.Timeout)
    check_mle_optional_tlv(command_msg, address_registration, mle.AddressRegistration)
    check_mle_optional_tlv(command_msg, address16, mle.Address16)
    check_mle_optional_tlv(command_msg, leader_data, mle.LeaderData)
    check_mle_optional_tlv(command_msg, network_data, mle.NetworkData)
    check_mle_optional_tlv(command_msg, response, mle.Response)
    check_mle_optional_tlv(command_msg, link_layer_frame_counter, mle.LinkLayerFrameCounter)
    check_mle_optional_tlv(command_msg, mle_frame_counter, mle.MleFrameCounter)

    if address_registration == CheckType.CONTAIN and CIDs:
        _check_address_registration(command_msg, CIDs)


def _check_address_registration(command_msg, CIDs=()):
    """Assert the Address Registration TLV has a compressed entry for every
    context ID in CIDs.
    """
    addresses = command_msg.assertMleMessageContainsTlv(mle.AddressRegistration).addresses
    for cid in CIDs:
        found = any(isinstance(address, mle.AddressCompressed) and cid == address.cid
                    for address in addresses)
        assert found, "AddressRegistration TLV doesn't have CID {} ".format(cid)


def get_sub_tlv(tlvs, tlv_type):
    """Return the first element of tlvs that is an instance of tlv_type, or None."""
    for sub_tlv in tlvs:
        if isinstance(sub_tlv, tlv_type):
            return sub_tlv
    return None


def check_address_registration_tlv(addr_reg_tlv, address_set):
    """Verify all addresses contained in address_set are contained in addr_reg_tlv."""
    assert all(addr in addr_reg_tlv.addresses for addr in address_set), \
        'Some addresses are not included in AddressRegistration TLV'


def assert_contains_tlv(tlvs, check_type, tlv_type):
    """Assert a TLV list does (or does not) contain a TLV of tlv_type.

    check_type:
      CheckType.CONTAIN     -- at least one must be present; the first is returned.
      CheckType.NOT_CONTAIN -- none may be present; None is returned.
      CheckType.OPTIONAL    -- nothing is asserted; the first match is returned
                               if there is one, otherwise None.

    Raises ValueError for any other check_type.
    """
    matches = [tlv for tlv in tlvs if isinstance(tlv, tlv_type)]
    # Compare with == (not `is`) so plain ints equal to a CheckType value work
    # too, consistent with check_mle_optional_tlv().
    if check_type == CheckType.CONTAIN:
        assert matches, 'Expected TLV {} is missing'.format(tlv_type)
        return matches[0]
    elif check_type == CheckType.NOT_CONTAIN:
        assert not matches, 'Unexpected TLV {} is present'.format(tlv_type)
        return None
    elif check_type == CheckType.OPTIONAL:
        return matches[0] if matches else None
    else:
        raise ValueError("Invalid check type: {}".format(check_type))


def check_discovery_request(command_msg):
    """Verify a properly formatted Thread Discovery Request command message.

    The MLE message must be unsecured and its Thread Discovery TLV must carry a
    Discovery Request whose version equals config.PROTOCOL_VERSION.
    """
    assert not isinstance(command_msg.mle, mle.MleMessageSecured)
    tlvs = command_msg.assertMleMessageContainsTlv(mle.ThreadDiscovery).tlvs
    request = assert_contains_tlv(tlvs, CheckType.CONTAIN, mesh_cop.DiscoveryRequest)
    assert request.version == config.PROTOCOL_VERSION


def check_discovery_response(command_msg, request_src_addr, steering_data=CheckType.OPTIONAL):
    """Verify a properly formatted Thread Discovery Response command message.

    The response must be unsecured, sent from a long MAC address to
    request_src_addr, and carry Discovery Response, Extended PAN ID and
    Network Name TLVs. Steering Data and Joiner UDP Port are both checked per
    steering_data. The Commissioner UDP Port is required when the response's
    native_flag is set.
    """
    assert not isinstance(command_msg.mle, mle.MleMessageSecured)
    assert command_msg.mac_header.src_address.type == common.MacAddressType.LONG
    assert command_msg.mac_header.dest_address == request_src_addr

    tlvs = command_msg.assertMleMessageContainsTlv(mle.ThreadDiscovery).tlvs
    response = assert_contains_tlv(tlvs, CheckType.CONTAIN, mesh_cop.DiscoveryResponse)
    assert response.version == config.PROTOCOL_VERSION
    assert_contains_tlv(tlvs, CheckType.CONTAIN, mesh_cop.ExtendedPanid)
    assert_contains_tlv(tlvs, CheckType.CONTAIN, mesh_cop.NetworkName)
    assert_contains_tlv(tlvs, steering_data, mesh_cop.SteeringData)
    assert_contains_tlv(tlvs, steering_data, mesh_cop.JoinerUdpPort)

    check_type = CheckType.CONTAIN if response.native_flag else CheckType.OPTIONAL
    assert_contains_tlv(tlvs, check_type, mesh_cop.CommissionerUdpPort)


def get_joiner_udp_port_in_discovery_response(command_msg):
    """Get the UDP port specified in a DISCOVERY RESPONSE message's Joiner UDP Port TLV."""
    tlvs = command_msg.assertMleMessageContainsTlv(mle.ThreadDiscovery).tlvs
    udp_port_tlv = assert_contains_tlv(tlvs, CheckType.CONTAIN, mesh_cop.JoinerUdpPort)
    return udp_port_tlv.udp_port


def check_joiner_commissioning_messages(commissioning_messages):
    """Verify COAP messages sent by joiner while commissioning process.

    The first message must be a JOIN_FIN.req without a Provisioning URL TLV and
    the second a JOIN_ENT.rsp.
    """
    print(commissioning_messages)
    assert len(commissioning_messages) >= 2

    join_fin_req = commissioning_messages[0]
    assert join_fin_req.type == mesh_cop.MeshCopMessageType.JOIN_FIN_REQ
    assert_contains_tlv(join_fin_req.tlvs, CheckType.NOT_CONTAIN, mesh_cop.ProvisioningUrl)

    join_ent_rsp = commissioning_messages[1]
    assert join_ent_rsp.type == mesh_cop.MeshCopMessageType.JOIN_ENT_RSP


def check_commissioner_commissioning_messages(commissioning_messages):
    """Verify COAP messages sent by commissioner while commissioning process
    include a JOIN_FIN.rsp.
    """
    assert any(msg.type == mesh_cop.MeshCopMessageType.JOIN_FIN_RSP for msg in commissioning_messages)


def check_joiner_router_commissioning_messages(commissioning_messages):
    """Verify COAP messages sent by joiner router while commissioning process
    include a JOIN_ENT.ntf.
    """
    assert any(msg.type == mesh_cop.MeshCopMessageType.JOIN_ENT_NTF for msg in commissioning_messages)


def check_payload_same(tp1, tp2):
    """Verify two payloads are the same. A payload is a tuple of TLVs.

    Both must have the same length, and for each TLV in tp2 the first TLV of
    the same type in tp1 must compare equal to it.
    """
    assert len(tp1) == len(tp2)
    for tlv in tp2:
        peer_tlv = get_sub_tlv(tp1, type(tlv))
        assert peer_tlv is not None and peer_tlv == tlv, \
            'peer_tlv:{}, tlv:{} type:{}'.format(peer_tlv, tlv, type(tlv))


def check_coap_message(msg, payloads, dest_addrs=None):
    """Verify a CoAP message's destination (optionally) and payload.

    When dest_addrs is given, the IPv6 destination address must equal one of
    its entries. The CoAP payload is then compared with payloads via
    check_payload_same().
    """
    if dest_addrs is not None:
        destination = msg.ipv6_packet.ipv6_header.destination_address
        assert any(destination == dest for dest in dest_addrs), 'Destination address incorrect'

    check_payload_same(msg.coap.payload, payloads)


class SinglePrefixCheck:
    """Match a Network Data Prefix TLV against an expected prefix and/or
    Border Router RLOC16.

    prefix is the expected prefix as a hex string (str or bytes, in the form
    produced by binascii.hexlify()); border_router_16 is compared with the
    Border Router sub-TLV's border_router_16 field. None skips that comparison.
    """

    def __init__(self, prefix=None, border_router_16=None):
        self._prefix = prefix
        self._border_router_16 = border_router_16

    def __repr__(self):
        # Makes PrefixesCheck's "Some prefix is absent" message informative.
        return '{}(prefix={!r}, border_router_16={!r})'.format(
            type(self).__name__, self._prefix, self._border_router_16)

    @staticmethod
    def _to_text(value):
        """Return value as text, decoding bytes-like objects as ASCII."""
        if isinstance(value, (bytes, bytearray)):
            return value.decode('ascii')
        return value

    def check(self, prefix_tlv):
        """Return True if prefix_tlv matches; asserts the Border Router and
        6LoWPAN ID sub-TLVs are present.
        """
        border_router_tlv = assert_contains_tlv(prefix_tlv.sub_tlvs, CheckType.CONTAIN, network_data.BorderRouter)
        # Only the presence of the 6LoWPAN ID sub-TLV is required.
        assert_contains_tlv(prefix_tlv.sub_tlvs, CheckType.CONTAIN, network_data.LowpanId)

        result = True
        if self._prefix is not None:
            # hexlify() returns str on Python 2 but bytes on Python 3; compare as
            # text so that str and bytes expectations both work on either version.
            actual_prefix = self._to_text(binascii.hexlify(prefix_tlv.prefix))
            result &= (self._to_text(self._prefix) == actual_prefix)
        if self._border_router_16 is not None:
            result &= (self._border_router_16 == border_router_tlv.border_router_16)

        return result


class PrefixesCheck:
    """Check a list of Prefix TLVs either by count or by content.

    If prefix_cnt > 0, only the count is checked (at least prefix_cnt
    prefixes). Otherwise every checker in prefix_check_list must match at
    least one prefix TLV.
    """

    def __init__(self, prefix_cnt=0, prefix_check_list=()):
        self._prefix_cnt = prefix_cnt
        self._prefix_check_list = prefix_check_list

    def check(self, prefix_tlvs):
        # if prefix_cnt is given, then check count only
        if self._prefix_cnt > 0:
            assert len(prefix_tlvs) >= self._prefix_cnt, 'prefix count is less than expected'
        else:
            for prefix_check in self._prefix_check_list:
                found = any(prefix_check.check(prefix_tlv) for prefix_tlv in prefix_tlvs)
                assert found, 'Some prefix is absent: {}'.format(prefix_check)


class CommissioningDataCheck:
    """Check a Commissioning Data TLV's stable flag (when stable is not None)
    and that it contains every sub-TLV type in sub_tlv_type_list.
    """

    def __init__(self, stable=None, sub_tlv_type_list=()):
        self._stable = stable
        self._sub_tlv_type_list = sub_tlv_type_list

    def check(self, commissioning_data_tlv):
        if self._stable is not None:
            assert self._stable == commissioning_data_tlv.stable, \
                'Commissioning Data stable flag is not correct'
        assert contains_tlvs(commissioning_data_tlv.sub_tlvs, self._sub_tlv_type_list), \
            'Some sub tlvs are missing in Commissioning Data'


class NetworkDataCheck:
    """Apply optional prefix and commissioning-data checks to a Network Data TLV.

    prefixes_check receives the list of Prefix TLVs; commissioning_data_check
    receives the (required) Commissioning Data TLV. Either may be None to skip.
    """

    def __init__(self, prefixes_check=None, commissioning_data_check=None):
        self._prefixes_check = prefixes_check
        self._commissioning_data_check = commissioning_data_check

    def check(self, network_data_tlv):
        if self._prefixes_check is not None:
            prefix_tlvs = [tlv for tlv in network_data_tlv.tlvs if isinstance(tlv, network_data.Prefix)]
            self._prefixes_check.check(prefix_tlvs)
        if self._commissioning_data_check is not None:
            commissioning_data_tlv = assert_contains_tlv(network_data_tlv.tlvs, CheckType.CONTAIN,
                                                         network_data.CommissioningData)
            self._commissioning_data_check.check(commissioning_data_tlv)