mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
Fix: missing clientId when serializing and deserializing a Response (#5231)
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
This commit is contained in:
parent
7246fd75d1
commit
113f6fbadd
@ -946,15 +946,18 @@ Response Serialization::deserializeResponse(std::istream& is)
|
||||
{
|
||||
auto requestId = su::deserialize<IdType>(is);
|
||||
auto errOrResult = su::deserialize<std::variant<std::string, Result>>(is);
|
||||
auto clientId = su::deserialize<std::optional<IdType>>(is);
|
||||
|
||||
return std::holds_alternative<std::string>(errOrResult) ? Response{requestId, std::get<std::string>(errOrResult)}
|
||||
: Response{requestId, std::get<Result>(errOrResult)};
|
||||
return std::holds_alternative<std::string>(errOrResult)
|
||||
? Response{requestId, std::get<std::string>(errOrResult), clientId}
|
||||
: Response{requestId, std::get<Result>(errOrResult), clientId};
|
||||
}
|
||||
|
||||
/// Writes `response` to the stream `os` as three consecutive fields:
/// request id, error-or-result, client id.
///
/// @param response Response whose implementation fields are emitted.
/// @param os       Output stream receiving the serialized fields.
void Serialization::serialize(Response const& response, std::ostream& os)
{
    auto& impl = *response.mImpl;

    // The emission order defines the layout; keep it request id -> payload -> client id.
    su::serialize(impl.mRequestId, os);
    su::serialize(impl.mErrOrResult, os);
    su::serialize(impl.mClientId, os);
}
|
||||
|
||||
size_t Serialization::serializedSize(Response const& response)
|
||||
@ -962,6 +965,7 @@ size_t Serialization::serializedSize(Response const& response)
|
||||
size_t totalSize = 0;
|
||||
totalSize += su::serializedSize(response.mImpl->mRequestId);
|
||||
totalSize += su::serializedSize(response.mImpl->mErrOrResult);
|
||||
totalSize += su::serializedSize(response.mImpl->mClientId);
|
||||
return totalSize;
|
||||
}
|
||||
|
||||
|
||||
@ -160,6 +160,7 @@ void compareResponse(texec::Response res, texec::Response res2)
|
||||
{
|
||||
compareResult(res.getResult(), res2.getResult());
|
||||
}
|
||||
EXPECT_EQ(res.getClientId(), res2.getClientId());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -428,11 +429,15 @@ TEST(SerializeUtilsTest, ResultResponse)
|
||||
auto val = texec::Response(1, "my error msg");
|
||||
testSerializeDeserialize(val);
|
||||
}
|
||||
{
|
||||
auto val = texec::Response(1, "my error msg", 2);
|
||||
testSerializeDeserialize(val);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SerializeUtilsTest, VectorResponses)
|
||||
{
|
||||
int numResponses = 10;
|
||||
int numResponses = 15;
|
||||
std::vector<texec::Response> responsesIn;
|
||||
for (int i = 0; i < numResponses; ++i)
|
||||
{
|
||||
@ -443,11 +448,16 @@ TEST(SerializeUtilsTest, VectorResponses)
|
||||
std::nullopt, std::vector<texec::FinishReason>{texec::FinishReason::kEND_ID}};
|
||||
responsesIn.emplace_back(i, res);
|
||||
}
|
||||
else
|
||||
else if (i < 10)
|
||||
{
|
||||
std::string errMsg = "my_err_msg" + std::to_string(i);
|
||||
responsesIn.emplace_back(i, errMsg);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string errMsg = "my_err_msg" + std::to_string(i);
|
||||
responsesIn.emplace_back(i, errMsg, i + 1);
|
||||
}
|
||||
}
|
||||
|
||||
auto buffer = texec::Serialization::serialize(responsesIn);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user