[mle] refactor MleRouter and Mle classes into a single Mle class (#11411)

This commit refactors the `Mle` modules and combines the `MleRouter`
and `Mle` classes into a single `Mle` class which now handles both
FTD and MTD functionalities.

The `MleRouter` and `Mle` classes were originally intended as
sub-classes, where the base class `Mle` would provide MTD and common
behaviors, and `MleRouter` would implement FTD-specific behaviors.
However, over the years and as new features were implemented, these
two classes became more intertwined, and the `Mle` class began to
include many FTD-related functions and interactions with `MleRouter`
private variables and methods.

This commit simplifies the code by combining the two into a single
class. The previous `mle_router.cpp` file is also renamed to
`mle_ftd.cpp` to indicate that it implements FTD-specific MLE
behaviors.
This commit is contained in:
Abtin Keshavarzian
2025-04-18 14:28:47 -07:00
committed by GitHub
parent f70749d21d
commit 0c1dfa0796
67 changed files with 1135 additions and 1255 deletions
+36 -45
View File
@@ -39,14 +39,11 @@
using namespace ot;
uint32_t otThreadGetChildTimeout(otInstance *aInstance)
{
return AsCoreType(aInstance).Get<Mle::MleRouter>().GetTimeout();
}
uint32_t otThreadGetChildTimeout(otInstance *aInstance) { return AsCoreType(aInstance).Get<Mle::Mle>().GetTimeout(); }
void otThreadSetChildTimeout(otInstance *aInstance, uint32_t aTimeout)
{
AsCoreType(aInstance).Get<Mle::MleRouter>().SetTimeout(aTimeout);
AsCoreType(aInstance).Get<Mle::Mle>().SetTimeout(aTimeout);
}
const otExtendedPanId *otThreadGetExtendedPanId(otInstance *aInstance)
@@ -60,7 +57,7 @@ otError otThreadSetExtendedPanId(otInstance *aInstance, const otExtendedPanId *a
Instance &instance = AsCoreType(aInstance);
const MeshCoP::ExtendedPanId &extPanId = AsCoreType(aExtendedPanId);
VerifyOrExit(instance.Get<Mle::MleRouter>().IsDisabled(), error = kErrorInvalidState);
VerifyOrExit(instance.Get<Mle::Mle>().IsDisabled(), error = kErrorInvalidState);
instance.Get<MeshCoP::ExtendedPanIdManager>().SetExtPanId(extPanId);
@@ -86,14 +83,14 @@ otLinkModeConfig otThreadGetLinkMode(otInstance *aInstance)
{
otLinkModeConfig config;
AsCoreType(aInstance).Get<Mle::MleRouter>().GetDeviceMode().Get(config);
AsCoreType(aInstance).Get<Mle::Mle>().GetDeviceMode().Get(config);
return config;
}
otError otThreadSetLinkMode(otInstance *aInstance, otLinkModeConfig aConfig)
{
return AsCoreType(aInstance).Get<Mle::MleRouter>().SetDeviceMode(Mle::DeviceMode(aConfig));
return AsCoreType(aInstance).Get<Mle::Mle>().SetDeviceMode(Mle::DeviceMode(aConfig));
}
void otThreadGetNetworkKey(otInstance *aInstance, otNetworkKey *aNetworkKey)
@@ -113,7 +110,7 @@ otError otThreadSetNetworkKey(otInstance *aInstance, const otNetworkKey *aKey)
Error error = kErrorNone;
Instance &instance = AsCoreType(aInstance);
VerifyOrExit(instance.Get<Mle::MleRouter>().IsDisabled(), error = kErrorInvalidState);
VerifyOrExit(instance.Get<Mle::Mle>().IsDisabled(), error = kErrorInvalidState);
instance.Get<KeyManager>().SetNetworkKey(AsCoreType(aKey));
@@ -132,7 +129,7 @@ otError otThreadSetNetworkKeyRef(otInstance *aInstance, otNetworkKeyRef aKeyRef)
VerifyOrExit(aKeyRef != 0, error = kErrorInvalidArgs);
VerifyOrExit(instance.Get<Mle::MleRouter>().IsDisabled(), error = kErrorInvalidState);
VerifyOrExit(instance.Get<Mle::Mle>().IsDisabled(), error = kErrorInvalidState);
instance.Get<KeyManager>().SetNetworkKeyRef((aKeyRef));
instance.Get<MeshCoP::ActiveDatasetManager>().Clear();
@@ -145,26 +142,26 @@ exit:
const otIp6Address *otThreadGetRloc(otInstance *aInstance)
{
return &AsCoreType(aInstance).Get<Mle::MleRouter>().GetMeshLocalRloc();
return &AsCoreType(aInstance).Get<Mle::Mle>().GetMeshLocalRloc();
}
const otIp6Address *otThreadGetMeshLocalEid(otInstance *aInstance)
{
return &AsCoreType(aInstance).Get<Mle::MleRouter>().GetMeshLocalEid();
return &AsCoreType(aInstance).Get<Mle::Mle>().GetMeshLocalEid();
}
const otMeshLocalPrefix *otThreadGetMeshLocalPrefix(otInstance *aInstance)
{
return &AsCoreType(aInstance).Get<Mle::MleRouter>().GetMeshLocalPrefix();
return &AsCoreType(aInstance).Get<Mle::Mle>().GetMeshLocalPrefix();
}
otError otThreadSetMeshLocalPrefix(otInstance *aInstance, const otMeshLocalPrefix *aMeshLocalPrefix)
{
Error error = kErrorNone;
VerifyOrExit(AsCoreType(aInstance).Get<Mle::MleRouter>().IsDisabled(), error = kErrorInvalidState);
VerifyOrExit(AsCoreType(aInstance).Get<Mle::Mle>().IsDisabled(), error = kErrorInvalidState);
AsCoreType(aInstance).Get<Mle::MleRouter>().SetMeshLocalPrefix(AsCoreType(aMeshLocalPrefix));
AsCoreType(aInstance).Get<Mle::Mle>().SetMeshLocalPrefix(AsCoreType(aMeshLocalPrefix));
AsCoreType(aInstance).Get<MeshCoP::ActiveDatasetManager>().Clear();
AsCoreType(aInstance).Get<MeshCoP::PendingDatasetManager>().Clear();
@@ -174,17 +171,17 @@ exit:
const otIp6Address *otThreadGetLinkLocalIp6Address(otInstance *aInstance)
{
return &AsCoreType(aInstance).Get<Mle::MleRouter>().GetLinkLocalAddress();
return &AsCoreType(aInstance).Get<Mle::Mle>().GetLinkLocalAddress();
}
const otIp6Address *otThreadGetLinkLocalAllThreadNodesMulticastAddress(otInstance *aInstance)
{
return &AsCoreType(aInstance).Get<Mle::MleRouter>().GetLinkLocalAllThreadNodesAddress();
return &AsCoreType(aInstance).Get<Mle::Mle>().GetLinkLocalAllThreadNodesAddress();
}
const otIp6Address *otThreadGetRealmLocalAllThreadNodesMulticastAddress(otInstance *aInstance)
{
return &AsCoreType(aInstance).Get<Mle::MleRouter>().GetRealmLocalAllThreadNodesAddress();
return &AsCoreType(aInstance).Get<Mle::Mle>().GetRealmLocalAllThreadNodesAddress();
}
otError otThreadGetServiceAloc(otInstance *aInstance, uint8_t aServiceId, otIp6Address *aServiceAloc)
@@ -207,7 +204,7 @@ otError otThreadSetNetworkName(otInstance *aInstance, const char *aNetworkName)
{
Error error = kErrorNone;
VerifyOrExit(AsCoreType(aInstance).Get<Mle::MleRouter>().IsDisabled(), error = kErrorInvalidState);
VerifyOrExit(AsCoreType(aInstance).Get<Mle::Mle>().IsDisabled(), error = kErrorInvalidState);
#if !OPENTHREAD_CONFIG_ALLOW_EMPTY_NETWORK_NAME
// Thread interfaces support a zero length name internally for backwards compatibility, but new names
@@ -233,7 +230,7 @@ otError otThreadSetDomainName(otInstance *aInstance, const char *aDomainName)
{
Error error = kErrorNone;
VerifyOrExit(AsCoreType(aInstance).Get<Mle::MleRouter>().IsDisabled(), error = kErrorInvalidState);
VerifyOrExit(AsCoreType(aInstance).Get<Mle::Mle>().IsDisabled(), error = kErrorInvalidState);
error = AsCoreType(aInstance).Get<MeshCoP::NetworkNameManager>().SetDomainName(aDomainName);
@@ -295,12 +292,9 @@ void otThreadSetKeySwitchGuardTime(otInstance *aInstance, uint16_t aKeySwitchGua
AsCoreType(aInstance).Get<KeyManager>().SetKeySwitchGuardTime(aKeySwitchGuardTime);
}
otError otThreadBecomeDetached(otInstance *aInstance)
{
return AsCoreType(aInstance).Get<Mle::MleRouter>().BecomeDetached();
}
otError otThreadBecomeDetached(otInstance *aInstance) { return AsCoreType(aInstance).Get<Mle::Mle>().BecomeDetached(); }
otError otThreadBecomeChild(otInstance *aInstance) { return AsCoreType(aInstance).Get<Mle::MleRouter>().BecomeChild(); }
otError otThreadBecomeChild(otInstance *aInstance) { return AsCoreType(aInstance).Get<Mle::Mle>().BecomeChild(); }
otError otThreadGetNextNeighborInfo(otInstance *aInstance, otNeighborInfoIterator *aIterator, otNeighborInfo *aInfo)
{
@@ -311,7 +305,7 @@ otError otThreadGetNextNeighborInfo(otInstance *aInstance, otNeighborInfoIterato
otDeviceRole otThreadGetDeviceRole(otInstance *aInstance)
{
return MapEnum(AsCoreType(aInstance).Get<Mle::MleRouter>().GetRole());
return MapEnum(AsCoreType(aInstance).Get<Mle::Mle>().GetRole());
}
const char *otThreadDeviceRoleToString(otDeviceRole aRole) { return Mle::RoleToString(MapEnum(aRole)); }
@@ -322,29 +316,26 @@ otError otThreadGetLeaderData(otInstance *aInstance, otLeaderData *aLeaderData)
AssertPointerIsNotNull(aLeaderData);
VerifyOrExit(AsCoreType(aInstance).Get<Mle::MleRouter>().IsAttached(), error = kErrorDetached);
*aLeaderData = AsCoreType(aInstance).Get<Mle::MleRouter>().GetLeaderData();
VerifyOrExit(AsCoreType(aInstance).Get<Mle::Mle>().IsAttached(), error = kErrorDetached);
*aLeaderData = AsCoreType(aInstance).Get<Mle::Mle>().GetLeaderData();
exit:
return error;
}
uint8_t otThreadGetLeaderRouterId(otInstance *aInstance)
{
return AsCoreType(aInstance).Get<Mle::MleRouter>().GetLeaderId();
}
uint8_t otThreadGetLeaderRouterId(otInstance *aInstance) { return AsCoreType(aInstance).Get<Mle::Mle>().GetLeaderId(); }
uint8_t otThreadGetLeaderWeight(otInstance *aInstance)
{
return AsCoreType(aInstance).Get<Mle::MleRouter>().GetLeaderData().GetWeighting();
return AsCoreType(aInstance).Get<Mle::Mle>().GetLeaderData().GetWeighting();
}
uint32_t otThreadGetPartitionId(otInstance *aInstance)
{
return AsCoreType(aInstance).Get<Mle::MleRouter>().GetLeaderData().GetPartitionId();
return AsCoreType(aInstance).Get<Mle::Mle>().GetLeaderData().GetPartitionId();
}
uint16_t otThreadGetRloc16(otInstance *aInstance) { return AsCoreType(aInstance).Get<Mle::MleRouter>().GetRloc16(); }
uint16_t otThreadGetRloc16(otInstance *aInstance) { return AsCoreType(aInstance).Get<Mle::Mle>().GetRloc16(); }
otError otThreadGetParentInfo(otInstance *aInstance, otRouterInfo *aParentInfo)
{
@@ -357,7 +348,7 @@ otError otThreadGetParentAverageRssi(otInstance *aInstance, int8_t *aParentRssi)
AssertPointerIsNotNull(aParentRssi);
*aParentRssi = AsCoreType(aInstance).Get<Mle::MleRouter>().GetParent().GetLinkInfo().GetAverageRss();
*aParentRssi = AsCoreType(aInstance).Get<Mle::Mle>().GetParent().GetLinkInfo().GetAverageRss();
VerifyOrExit(*aParentRssi != Radio::kInvalidRssi, error = kErrorFailed);
@@ -371,7 +362,7 @@ otError otThreadGetParentLastRssi(otInstance *aInstance, int8_t *aLastRssi)
AssertPointerIsNotNull(aLastRssi);
*aLastRssi = AsCoreType(aInstance).Get<Mle::MleRouter>().GetParent().GetLinkInfo().GetLastRss();
*aLastRssi = AsCoreType(aInstance).Get<Mle::Mle>().GetParent().GetLinkInfo().GetLastRss();
VerifyOrExit(*aLastRssi != Radio::kInvalidRssi, error = kErrorFailed);
@@ -390,11 +381,11 @@ otError otThreadSetEnabled(otInstance *aInstance, bool aEnabled)
if (aEnabled)
{
error = AsCoreType(aInstance).Get<Mle::MleRouter>().Start();
error = AsCoreType(aInstance).Get<Mle::Mle>().Start();
}
else
{
AsCoreType(aInstance).Get<Mle::MleRouter>().Stop();
AsCoreType(aInstance).Get<Mle::Mle>().Stop();
}
return error;
@@ -407,7 +398,7 @@ bool otThreadIsSingleton(otInstance *aInstance)
bool isSingleton = false;
#if OPENTHREAD_FTD
isSingleton = AsCoreType(aInstance).Get<Mle::MleRouter>().IsSingleton();
isSingleton = AsCoreType(aInstance).Get<Mle::Mle>().IsSingleton();
#else
OT_UNUSED_VARIABLE(aInstance);
#endif
@@ -470,14 +461,14 @@ void otThreadResetTimeInQueueStat(otInstance *aInstance)
const otMleCounters *otThreadGetMleCounters(otInstance *aInstance)
{
return &AsCoreType(aInstance).Get<Mle::MleRouter>().GetCounters();
return &AsCoreType(aInstance).Get<Mle::Mle>().GetCounters();
}
void otThreadResetMleCounters(otInstance *aInstance) { AsCoreType(aInstance).Get<Mle::MleRouter>().ResetCounters(); }
void otThreadResetMleCounters(otInstance *aInstance) { AsCoreType(aInstance).Get<Mle::Mle>().ResetCounters(); }
uint32_t otThreadGetCurrentAttachDuration(otInstance *aInstance)
{
return AsCoreType(aInstance).Get<Mle::MleRouter>().GetCurrentAttachDuration();
return AsCoreType(aInstance).Get<Mle::Mle>().GetCurrentAttachDuration();
}
#if OPENTHREAD_CONFIG_MLE_PARENT_RESPONSE_CALLBACK_API_ENABLE
@@ -485,7 +476,7 @@ void otThreadRegisterParentResponseCallback(otInstance *aInst
otThreadParentResponseCallback aCallback,
void *aContext)
{
AsCoreType(aInstance).Get<Mle::MleRouter>().RegisterParentResponseStatsCallback(aCallback, aContext);
AsCoreType(aInstance).Get<Mle::Mle>().RegisterParentResponseStatsCallback(aCallback, aContext);
}
#endif
@@ -506,7 +497,7 @@ bool otThreadIsAnycastLocateInProgress(otInstance *aInstance)
otError otThreadDetachGracefully(otInstance *aInstance, otDetachGracefullyCallback aCallback, void *aContext)
{
return AsCoreType(aInstance).Get<Mle::MleRouter>().DetachGracefully(aCallback, aContext);
return AsCoreType(aInstance).Get<Mle::Mle>().DetachGracefully(aCallback, aContext);
}
#if OPENTHREAD_CONFIG_DYNAMIC_STORE_FRAME_AHEAD_COUNTER_ENABLE