feat: support Redis 7.0 sharded pub/sub (#28333)

wangxiaolei 2025-11-21 10:33:52 +08:00 committed by GitHub
parent e260815c5e
commit cad2991946
7 changed files with 1532 additions and 164 deletions
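
In short, this commit adds a sharded counterpart to the existing BroadcastChannel. A minimal usage sketch of the new API, based on the code and tests in the diff below: the connection settings are placeholders, the server must be Redis 7.0 or newer (SPUBLISH/SSUBSCRIBE), and the client is created with decode_responses=False, as the integration tests do.

import redis

from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel

# Placeholder connection settings; payloads are raw bytes, so decode_responses stays False.
client = redis.Redis(host="localhost", port=6379, decode_responses=False)
channel = ShardedRedisBroadcastChannel(client)
topic = channel.topic("events")

# subscribe() returns a Subscription usable as a context manager; entering it issues
# SSUBSCRIBE and starts the background listener thread.
with topic.subscribe() as subscription:
    # Publishing normally happens in another process or thread; it is inlined here.
    topic.as_producer().publish(b"hello sharded world")
    # receive() blocks for up to the timeout and returns None if nothing arrived.
    message = subscription.receive(timeout=1.0)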

View File

@@ -1,3 +1,4 @@
from .channel import BroadcastChannel
from .sharded_channel import ShardedRedisBroadcastChannel
__all__ = ["BroadcastChannel"]
__all__ = ["BroadcastChannel", "ShardedRedisBroadcastChannel"]

View File

@@ -0,0 +1,205 @@
import logging
import queue
import threading
import types
from collections.abc import Generator, Iterator
from typing import Self
from libs.broadcast_channel.channel import Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis.client import PubSub
_logger = logging.getLogger(__name__)
class RedisSubscriptionBase(Subscription):
"""Base class for Redis pub/sub subscriptions with common functionality.
This class provides shared functionality for both regular and sharded
Redis pub/sub subscriptions, reducing code duplication and improving
maintainability.
"""
def __init__(
self,
pubsub: PubSub,
topic: str,
):
# The _pubsub is None only if the subscription is closed.
self._pubsub: PubSub | None = pubsub
self._topic = topic
self._closed = threading.Event()
self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
self._dropped_count = 0
self._listener_thread: threading.Thread | None = None
self._start_lock = threading.Lock()
self._started = False
def _start_if_needed(self) -> None:
"""Start the subscription if not already started."""
with self._start_lock:
if self._started:
return
if self._closed.is_set():
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
if self._pubsub is None:
raise SubscriptionClosedError(
f"The Redis {self._get_subscription_type()} subscription has been cleaned up"
)
self._subscribe()
_logger.debug("Subscribed to %s channel %s", self._get_subscription_type(), self._topic)
self._listener_thread = threading.Thread(
target=self._listen,
name=f"redis-{self._get_subscription_type().replace(' ', '-')}-broadcast-{self._topic}",
daemon=True,
)
self._listener_thread.start()
self._started = True
def _listen(self) -> None:
"""Main listener loop for processing messages."""
pubsub = self._pubsub
assert pubsub is not None, "PubSub should not be None when the listener starts."
while not self._closed.is_set():
raw_message = self._get_message()
if raw_message is None:
continue
if raw_message.get("type") != self._get_message_type():
continue
channel_field = raw_message.get("channel")
if isinstance(channel_field, bytes):
channel_name = channel_field.decode("utf-8")
elif isinstance(channel_field, str):
channel_name = channel_field
else:
channel_name = str(channel_field)
if channel_name != self._topic:
_logger.warning(
"Ignoring %s message from unexpected channel %s", self._get_subscription_type(), channel_name
)
continue
payload_bytes: bytes | None = raw_message.get("data")
if not isinstance(payload_bytes, bytes):
_logger.error(
"Received invalid data from %s channel %s, type=%s",
self._get_subscription_type(),
self._topic,
type(payload_bytes),
)
continue
self._enqueue_message(payload_bytes)
_logger.debug("%s listener thread stopped for channel %s", self._get_subscription_type().title(), self._topic)
self._unsubscribe()
pubsub.close()
_logger.debug("%s PubSub closed for topic %s", self._get_subscription_type().title(), self._topic)
self._pubsub = None
def _enqueue_message(self, payload: bytes) -> None:
"""Enqueue a message to the internal queue with dropping behavior."""
while not self._closed.is_set():
try:
self._queue.put_nowait(payload)
return
except queue.Full:
try:
self._queue.get_nowait()
self._dropped_count += 1
_logger.debug(
"Dropped message from Redis %s subscription, topic=%s, total_dropped=%d",
self._get_subscription_type(),
self._topic,
self._dropped_count,
)
except queue.Empty:
continue
return
def _message_iterator(self) -> Generator[bytes, None, None]:
"""Iterator for consuming messages from the subscription."""
while not self._closed.is_set():
try:
item = self._queue.get(timeout=0.1)
except queue.Empty:
continue
yield item
def __iter__(self) -> Iterator[bytes]:
"""Return an iterator over messages from the subscription."""
if self._closed.is_set():
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
self._start_if_needed()
return iter(self._message_iterator())
def receive(self, timeout: float | None = None) -> bytes | None:
"""Receive the next message from the subscription."""
if self._closed.is_set():
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
self._start_if_needed()
try:
item = self._queue.get(timeout=timeout)
except queue.Empty:
return None
return item
def __enter__(self) -> Self:
"""Context manager entry point."""
self._start_if_needed()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool | None:
"""Context manager exit point."""
self.close()
return None
def close(self) -> None:
"""Close the subscription and clean up resources."""
if self._closed.is_set():
return
self._closed.set()
# NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the
# message retrieval method should NOT be called concurrently.
#
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
listener = self._listener_thread
if listener is not None:
listener.join(timeout=1.0)
self._listener_thread = None
# Abstract methods to be implemented by subclasses
def _get_subscription_type(self) -> str:
"""Return the subscription type (e.g., 'regular' or 'sharded')."""
raise NotImplementedError
def _subscribe(self) -> None:
"""Subscribe to the Redis topic using the appropriate command."""
raise NotImplementedError
def _unsubscribe(self) -> None:
"""Unsubscribe from the Redis topic using the appropriate command."""
raise NotImplementedError
def _get_message(self) -> dict | None:
"""Get a message from Redis using the appropriate method."""
raise NotImplementedError
def _get_message_type(self) -> str:
"""Return the expected message type (e.g., 'message' or 'smessage')."""
raise NotImplementedError
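
As a usage note, the Subscription surface defined here supports two consumption styles; a rough sketch, assuming `subscription` comes from one of the Topic.subscribe() implementations in the files below:

def consume(subscription) -> list[bytes]:
    received: list[bytes] = []
    # Style 1: poll with receive(); it blocks for up to the timeout and returns None on expiry.
    first = subscription.receive(timeout=0.5)
    if first is not None:
        received.append(first)
    # Style 2: iterate inside a context manager; the loop ends once close() is called
    # (typically from another thread), because the iterator exits when the closed flag is set.
    with subscription:
        for payload in subscription:
            received.append(payload)
    return received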

View File

@@ -1,24 +1,15 @@
import logging
import queue
import threading
import types
from collections.abc import Generator, Iterator
from typing import Self
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis import Redis
from redis.client import PubSub
_logger = logging.getLogger(__name__)
from ._subscription import RedisSubscriptionBase
class BroadcastChannel:
"""
Redis Pub/Sub based broadcast channel implementation.
Redis Pub/Sub based broadcast channel implementation (regular, non-sharded).
Provides "at most once" delivery semantics for messages published to channels.
Uses Redis PUBLISH/SUBSCRIBE commands for real-time message delivery.
Provides "at most once" delivery semantics for messages published to channels
using Redis PUBLISH/SUBSCRIBE commands for real-time message delivery.
The `redis_client` used to construct BroadcastChannel should have `decode_responses` set to `False`.
"""
@@ -54,147 +45,23 @@ class Topic:
)
class _RedisSubscription(Subscription):
def __init__(
self,
pubsub: PubSub,
topic: str,
):
# The _pubsub is None only if the subscription is closed.
self._pubsub: PubSub | None = pubsub
self._topic = topic
self._closed = threading.Event()
self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
self._dropped_count = 0
self._listener_thread: threading.Thread | None = None
self._start_lock = threading.Lock()
self._started = False
class _RedisSubscription(RedisSubscriptionBase):
"""Regular Redis pub/sub subscription implementation."""
def _start_if_needed(self) -> None:
with self._start_lock:
if self._started:
return
if self._closed.is_set():
raise SubscriptionClosedError("The Redis subscription is closed")
if self._pubsub is None:
raise SubscriptionClosedError("The Redis subscription has been cleaned up")
def _get_subscription_type(self) -> str:
return "regular"
self._pubsub.subscribe(self._topic)
_logger.debug("Subscribed to channel %s", self._topic)
def _subscribe(self) -> None:
assert self._pubsub is not None
self._pubsub.subscribe(self._topic)
self._listener_thread = threading.Thread(
target=self._listen,
name=f"redis-broadcast-{self._topic}",
daemon=True,
)
self._listener_thread.start()
self._started = True
def _unsubscribe(self) -> None:
assert self._pubsub is not None
self._pubsub.unsubscribe(self._topic)
def _listen(self) -> None:
pubsub = self._pubsub
assert pubsub is not None, "PubSub should not be None while starting listening."
while not self._closed.is_set():
raw_message = pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
def _get_message(self) -> dict | None:
assert self._pubsub is not None
return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
if raw_message is None:
continue
if raw_message.get("type") != "message":
continue
channel_field = raw_message.get("channel")
if isinstance(channel_field, bytes):
channel_name = channel_field.decode("utf-8")
elif isinstance(channel_field, str):
channel_name = channel_field
else:
channel_name = str(channel_field)
if channel_name != self._topic:
_logger.warning("Ignoring message from unexpected channel %s", channel_name)
continue
payload_bytes: bytes | None = raw_message.get("data")
if not isinstance(payload_bytes, bytes):
_logger.error("Received invalid data from channel %s, type=%s", self._topic, type(payload_bytes))
continue
self._enqueue_message(payload_bytes)
_logger.debug("Listener thread stopped for channel %s", self._topic)
pubsub.unsubscribe(self._topic)
pubsub.close()
_logger.debug("PubSub closed for topic %s", self._topic)
self._pubsub = None
def _enqueue_message(self, payload: bytes) -> None:
while not self._closed.is_set():
try:
self._queue.put_nowait(payload)
return
except queue.Full:
try:
self._queue.get_nowait()
self._dropped_count += 1
_logger.debug(
"Dropped message from Redis subscription, topic=%s, total_dropped=%d",
self._topic,
self._dropped_count,
)
except queue.Empty:
continue
return
def _message_iterator(self) -> Generator[bytes, None, None]:
while not self._closed.is_set():
try:
item = self._queue.get(timeout=0.1)
except queue.Empty:
continue
yield item
def __iter__(self) -> Iterator[bytes]:
if self._closed.is_set():
raise SubscriptionClosedError("The Redis subscription is closed")
self._start_if_needed()
return iter(self._message_iterator())
def receive(self, timeout: float | None = None) -> bytes | None:
if self._closed.is_set():
raise SubscriptionClosedError("The Redis subscription is closed")
self._start_if_needed()
try:
item = self._queue.get(timeout=timeout)
except queue.Empty:
return None
return item
def __enter__(self) -> Self:
self._start_if_needed()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool | None:
self.close()
return None
def close(self) -> None:
if self._closed.is_set():
return
self._closed.set()
# NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the `PubSub.get_message`
# method should NOT be called concurrently.
#
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
listener = self._listener_thread
if listener is not None:
listener.join(timeout=1.0)
self._listener_thread = None
def _get_message_type(self) -> str:
return "message"

View File

@@ -0,0 +1,65 @@
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis
from ._subscription import RedisSubscriptionBase
class ShardedRedisBroadcastChannel:
"""
Redis 7.0+ Sharded Pub/Sub based broadcast channel implementation.
Provides "at most once" delivery semantics using SPUBLISH/SSUBSCRIBE commands,
distributing channels across Redis cluster nodes for better scalability.
"""
def __init__(
self,
redis_client: Redis,
):
self._client = redis_client
def topic(self, topic: str) -> "ShardedTopic":
return ShardedTopic(self._client, topic)
class ShardedTopic:
def __init__(self, redis_client: Redis, topic: str):
self._client = redis_client
self._topic = topic
def as_producer(self) -> Producer:
return self
def publish(self, payload: bytes) -> None:
self._client.spublish(self._topic, payload) # type: ignore[attr-defined]
def as_subscriber(self) -> Subscriber:
return self
def subscribe(self) -> Subscription:
return _RedisShardedSubscription(
pubsub=self._client.pubsub(),
topic=self._topic,
)
class _RedisShardedSubscription(RedisSubscriptionBase):
"""Redis 7.0+ sharded pub/sub subscription implementation."""
def _get_subscription_type(self) -> str:
return "sharded"
def _subscribe(self) -> None:
assert self._pubsub is not None
self._pubsub.ssubscribe(self._topic) # type: ignore[attr-defined]
def _unsubscribe(self) -> None:
assert self._pubsub is not None
self._pubsub.sunsubscribe(self._topic) # type: ignore[attr-defined]
def _get_message(self) -> dict | None:
assert self._pubsub is not None
return self._pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=0.1) # type: ignore[attr-defined]
def _get_message_type(self) -> str:
return "smessage"

View File

@@ -107,7 +107,11 @@ class TestRedisBroadcastChannelIntegration:
assert received_messages[0] == message
def test_multiple_subscribers_same_topic(self, broadcast_channel: BroadcastChannel):
"""Test message broadcasting to multiple subscribers."""
"""Test message broadcasting to multiple subscribers.
This test ensures the publisher only sends after all subscribers have actually started
their Redis Pub/Sub subscriptions to avoid race conditions/flakiness.
"""
topic_name = "broadcast-topic"
message = b"broadcast message"
subscriber_count = 5
@@ -116,16 +120,33 @@ class TestRedisBroadcastChannelIntegration:
topic = broadcast_channel.topic(topic_name)
producer = topic.as_producer()
subscriptions = [topic.subscribe() for _ in range(subscriber_count)]
ready_events = [threading.Event() for _ in range(subscriber_count)]
def producer_thread():
time.sleep(0.2) # Allow all subscribers to connect
# Wait for all subscribers to start (with a reasonable timeout)
deadline = time.time() + 5.0
for ev in ready_events:
remaining = deadline - time.time()
if remaining <= 0:
break
ev.wait(timeout=max(0.0, remaining))
# Now publish the message
producer.publish(message)
time.sleep(0.2)
for sub in subscriptions:
sub.close()
def consumer_thread(subscription: Subscription) -> list[bytes]:
def consumer_thread(subscription: Subscription, ready_event: threading.Event) -> list[bytes]:
received_msgs = []
# Prime the subscription to ensure the underlying Pub/Sub is started
try:
_ = subscription.receive(0.01)
except SubscriptionClosedError:
ready_event.set()
return received_msgs
# Signal readiness after first receive returns (subscription started)
ready_event.set()
while True:
try:
msg = subscription.receive(0.1)
@@ -141,7 +162,10 @@ class TestRedisBroadcastChannelIntegration:
# Run producer and consumers
with ThreadPoolExecutor(max_workers=subscriber_count + 1) as executor:
producer_future = executor.submit(producer_thread)
consumer_futures = [executor.submit(consumer_thread, subscription) for subscription in subscriptions]
consumer_futures = [
executor.submit(consumer_thread, subscription, ready_events[idx])
for idx, subscription in enumerate(subscriptions)
]
# Wait for completion
producer_future.result(timeout=10.0)

View File

@@ -0,0 +1,317 @@
"""
Integration tests for Redis sharded broadcast channel implementation using TestContainers.
Covers real Redis 7+ sharded pub/sub interactions including:
- Multiple producer/consumer scenarios
- Topic isolation
- Concurrency under load
- Resource cleanup accounting via PUBSUB SHARDNUMSUB
"""
import threading
import time
import uuid
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
import redis
from testcontainers.redis import RedisContainer
from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic
from libs.broadcast_channel.exc import SubscriptionClosedError
from libs.broadcast_channel.redis.sharded_channel import (
ShardedRedisBroadcastChannel,
)
class TestShardedRedisBroadcastChannelIntegration:
"""Integration tests for Redis sharded broadcast channel with real Redis 7 instance."""
@pytest.fixture(scope="class")
def redis_container(self) -> Iterator[RedisContainer]:
"""Create a Redis 7 container for integration testing (required for sharded pub/sub)."""
# Redis 7+ is required for SPUBLISH/SSUBSCRIBE
with RedisContainer(image="redis:7-alpine") as container:
yield container
@pytest.fixture(scope="class")
def redis_client(self, redis_container: RedisContainer) -> redis.Redis:
"""Create a Redis client connected to the test container."""
host = redis_container.get_container_host_ip()
port = redis_container.get_exposed_port(6379)
return redis.Redis(host=host, port=port, decode_responses=False)
@pytest.fixture
def broadcast_channel(self, redis_client: redis.Redis) -> BroadcastChannel:
"""Create a ShardedRedisBroadcastChannel instance with real Redis client."""
return ShardedRedisBroadcastChannel(redis_client)
@classmethod
def _get_test_topic_name(cls) -> str:
return f"test_sharded_topic_{uuid.uuid4()}"
# ==================== Basic Functionality Tests ====================
def test_close_an_active_subscription_should_stop_iteration(self, broadcast_channel: BroadcastChannel):
topic_name = self._get_test_topic_name()
topic = broadcast_channel.topic(topic_name)
subscription = topic.subscribe()
consuming_event = threading.Event()
def consume():
msgs = []
consuming_event.set()
for msg in subscription:
msgs.append(msg)
return msgs
with ThreadPoolExecutor(max_workers=1) as executor:
consumer_future = executor.submit(consume)
consuming_event.wait()
subscription.close()
msgs = consumer_future.result(timeout=2)
assert msgs == []
def test_end_to_end_messaging(self, broadcast_channel: BroadcastChannel):
"""Test complete end-to-end messaging flow (sharded)."""
topic_name = self._get_test_topic_name()
message = b"hello sharded world"
topic = broadcast_channel.topic(topic_name)
producer = topic.as_producer()
subscription = topic.subscribe()
def producer_thread():
time.sleep(0.1) # Small delay to ensure subscriber is ready
producer.publish(message)
time.sleep(0.1)
subscription.close()
def consumer_thread() -> list[bytes]:
received_messages = []
for msg in subscription:
received_messages.append(msg)
return received_messages
with ThreadPoolExecutor(max_workers=2) as executor:
producer_future = executor.submit(producer_thread)
consumer_future = executor.submit(consumer_thread)
producer_future.result(timeout=5.0)
received_messages = consumer_future.result(timeout=5.0)
assert len(received_messages) == 1
assert received_messages[0] == message
def test_multiple_subscribers_same_topic(self, broadcast_channel: BroadcastChannel):
"""Test message broadcasting to multiple sharded subscribers."""
topic_name = self._get_test_topic_name()
message = b"broadcast sharded message"
subscriber_count = 5
topic = broadcast_channel.topic(topic_name)
producer = topic.as_producer()
subscriptions = [topic.subscribe() for _ in range(subscriber_count)]
def producer_thread():
time.sleep(0.2) # Allow all subscribers to connect
producer.publish(message)
time.sleep(0.2)
for sub in subscriptions:
sub.close()
def consumer_thread(subscription: Subscription) -> list[bytes]:
received_msgs = []
while True:
try:
msg = subscription.receive(0.1)
except SubscriptionClosedError:
break
if msg is None:
continue
received_msgs.append(msg)
if len(received_msgs) >= 1:
break
return received_msgs
with ThreadPoolExecutor(max_workers=subscriber_count + 1) as executor:
producer_future = executor.submit(producer_thread)
consumer_futures = [executor.submit(consumer_thread, subscription) for subscription in subscriptions]
producer_future.result(timeout=10.0)
msgs_by_consumers = []
for future in as_completed(consumer_futures, timeout=10.0):
msgs_by_consumers.append(future.result())
for subscription in subscriptions:
subscription.close()
for msgs in msgs_by_consumers:
assert len(msgs) == 1
assert msgs[0] == message
def test_topic_isolation(self, broadcast_channel: BroadcastChannel):
"""Test that different sharded topics are isolated from each other."""
topic1_name = self._get_test_topic_name()
topic2_name = self._get_test_topic_name()
message1 = b"message for sharded topic1"
message2 = b"message for sharded topic2"
topic1 = broadcast_channel.topic(topic1_name)
topic2 = broadcast_channel.topic(topic2_name)
def producer_thread():
time.sleep(0.1)
topic1.publish(message1)
topic2.publish(message2)
def consumer_by_thread(topic: Topic) -> list[bytes]:
subscription = topic.subscribe()
received = []
with subscription:
for msg in subscription:
received.append(msg)
if len(received) >= 1:
break
return received
with ThreadPoolExecutor(max_workers=3) as executor:
producer_future = executor.submit(producer_thread)
consumer1_future = executor.submit(consumer_by_thread, topic1)
consumer2_future = executor.submit(consumer_by_thread, topic2)
producer_future.result(timeout=5.0)
received_by_topic1 = consumer1_future.result(timeout=5.0)
received_by_topic2 = consumer2_future.result(timeout=5.0)
assert len(received_by_topic1) == 1
assert len(received_by_topic2) == 1
assert received_by_topic1[0] == message1
assert received_by_topic2[0] == message2
# ==================== Performance / Concurrency ====================
def test_concurrent_producers(self, broadcast_channel: BroadcastChannel):
"""Test multiple producers publishing to the same sharded topic."""
topic_name = self._get_test_topic_name()
producer_count = 5
messages_per_producer = 5
topic = broadcast_channel.topic(topic_name)
subscription = topic.subscribe()
expected_total = producer_count * messages_per_producer
consumer_ready = threading.Event()
def producer_thread(producer_idx: int) -> set[bytes]:
producer = topic.as_producer()
produced = set()
for i in range(messages_per_producer):
message = f"producer_{producer_idx}_msg_{i}".encode()
produced.add(message)
producer.publish(message)
time.sleep(0.001)
return produced
def consumer_thread() -> set[bytes]:
received_msgs: set[bytes] = set()
with subscription:
consumer_ready.set()
while True:
try:
msg = subscription.receive(timeout=0.1)
except SubscriptionClosedError:
break
if msg is None:
if len(received_msgs) >= expected_total:
break
else:
continue
received_msgs.add(msg)
return received_msgs
with ThreadPoolExecutor(max_workers=producer_count + 1) as executor:
consumer_future = executor.submit(consumer_thread)
consumer_ready.wait()
producer_futures = [executor.submit(producer_thread, i) for i in range(producer_count)]
sent_msgs: set[bytes] = set()
for future in as_completed(producer_futures, timeout=30.0):
sent_msgs.update(future.result())
subscription.close()
consumer_received_msgs = consumer_future.result(timeout=30.0)
assert sent_msgs == consumer_received_msgs
# ==================== Resource Management ====================
def _get_sharded_numsub(self, redis_client: redis.Redis, topic_name: str) -> int:
"""Return number of sharded subscribers for a given topic using PUBSUB SHARDNUMSUB.
Redis returns a flat list like [channel1, count1, channel2, count2, ...].
We request a single channel, so parse accordingly.
"""
try:
res = redis_client.execute_command("PUBSUB", "SHARDNUMSUB", topic_name)
except Exception:
return 0
# Normalize different possible return shapes from drivers
if isinstance(res, (list, tuple)):
# Expect [channel, count] (bytes/str, int)
if len(res) >= 2:
key = res[0]
cnt = res[1]
if key == topic_name or (isinstance(key, (bytes, bytearray)) and key == topic_name.encode()):
try:
return int(cnt)
except Exception:
return 0
# Fallback parse pairs
count = 0
for i in range(0, len(res) - 1, 2):
key = res[i]
cnt = res[i + 1]
if key == topic_name or (isinstance(key, (bytes, bytearray)) and key == topic_name.encode()):
try:
count = int(cnt)
except Exception:
count = 0
break
return count
return 0
def test_subscription_cleanup(self, broadcast_channel: BroadcastChannel, redis_client: redis.Redis):
"""Test proper cleanup of sharded subscription resources via SHARDNUMSUB."""
topic_name = self._get_test_topic_name()
topic = broadcast_channel.topic(topic_name)
def _consume(sub: Subscription):
for _ in sub:
pass
subscriptions = []
for _ in range(5):
subscription = topic.subscribe()
subscriptions.append(subscription)
thread = threading.Thread(target=_consume, args=(subscription,))
thread.start()
time.sleep(0.01)
# Verify subscriptions are active using SHARDNUMSUB
topic_subscribers = self._get_sharded_numsub(redis_client, topic_name)
assert topic_subscribers >= 5
# Close all subscriptions
for subscription in subscriptions:
subscription.close()
# Wait a bit for cleanup
time.sleep(1)
# Verify subscriptions are cleaned up
topic_subscribers_after = self._get_sharded_numsub(redis_client, topic_name)
assert topic_subscribers_after == 0

View File

@@ -25,6 +25,11 @@ from libs.broadcast_channel.redis.channel import (
Topic,
_RedisSubscription,
)
from libs.broadcast_channel.redis.sharded_channel import (
ShardedRedisBroadcastChannel,
ShardedTopic,
_RedisShardedSubscription,
)
class TestBroadcastChannel:
@@ -39,9 +44,14 @@ class TestBroadcastChannel:
@pytest.fixture
def broadcast_channel(self, mock_redis_client: MagicMock) -> RedisBroadcastChannel:
"""Create a BroadcastChannel instance with mock Redis client."""
"""Create a BroadcastChannel instance with mock Redis client (regular)."""
return RedisBroadcastChannel(mock_redis_client)
@pytest.fixture
def sharded_broadcast_channel(self, mock_redis_client: MagicMock) -> ShardedRedisBroadcastChannel:
"""Create a ShardedRedisBroadcastChannel instance with mock Redis client."""
return ShardedRedisBroadcastChannel(mock_redis_client)
def test_topic_creation(self, broadcast_channel: RedisBroadcastChannel, mock_redis_client: MagicMock):
"""Test that topic() method returns a Topic instance with correct parameters."""
topic_name = "test-topic"
@@ -60,6 +70,38 @@ class TestBroadcastChannel:
assert topic1._topic == "topic1"
assert topic2._topic == "topic2"
def test_sharded_topic_creation(
self, sharded_broadcast_channel: ShardedRedisBroadcastChannel, mock_redis_client: MagicMock
):
"""Test that topic() on ShardedRedisBroadcastChannel returns a ShardedTopic instance with correct parameters."""
topic_name = "test-sharded-topic"
sharded_topic = sharded_broadcast_channel.topic(topic_name)
assert isinstance(sharded_topic, ShardedTopic)
assert sharded_topic._client == mock_redis_client
assert sharded_topic._topic == topic_name
def test_sharded_topic_isolation(self, sharded_broadcast_channel: ShardedRedisBroadcastChannel):
"""Test that different sharded topic names create isolated ShardedTopic instances."""
topic1 = sharded_broadcast_channel.topic("sharded-topic1")
topic2 = sharded_broadcast_channel.topic("sharded-topic2")
assert topic1 is not topic2
assert topic1._topic == "sharded-topic1"
assert topic2._topic == "sharded-topic2"
def test_regular_and_sharded_topic_isolation(
self, broadcast_channel: RedisBroadcastChannel, sharded_broadcast_channel: ShardedRedisBroadcastChannel
):
"""Test that regular topics and sharded topics from different channels are separate instances."""
regular_topic = broadcast_channel.topic("test-topic")
sharded_topic = sharded_broadcast_channel.topic("test-topic")
assert isinstance(regular_topic, Topic)
assert isinstance(sharded_topic, ShardedTopic)
assert regular_topic is not sharded_topic
assert regular_topic._topic == sharded_topic._topic
class TestTopic:
"""Test cases for the Topic class."""
@@ -98,6 +140,51 @@ class TestTopic:
mock_redis_client.publish.assert_called_once_with("test-topic", payload)
class TestShardedTopic:
"""Test cases for the ShardedTopic class."""
@pytest.fixture
def mock_redis_client(self) -> MagicMock:
"""Create a mock Redis client for testing."""
client = MagicMock()
client.pubsub.return_value = MagicMock()
return client
@pytest.fixture
def sharded_topic(self, mock_redis_client: MagicMock) -> ShardedTopic:
"""Create a ShardedTopic instance for testing."""
return ShardedTopic(mock_redis_client, "test-sharded-topic")
def test_as_producer_returns_self(self, sharded_topic: ShardedTopic):
"""Test that as_producer() returns self as Producer interface."""
producer = sharded_topic.as_producer()
assert producer is sharded_topic
# Producer is a Protocol, check duck typing instead
assert hasattr(producer, "publish")
def test_as_subscriber_returns_self(self, sharded_topic: ShardedTopic):
"""Test that as_subscriber() returns self as Subscriber interface."""
subscriber = sharded_topic.as_subscriber()
assert subscriber is sharded_topic
# Subscriber is a Protocol, check duck typing instead
assert hasattr(subscriber, "subscribe")
def test_publish_calls_redis_spublish(self, sharded_topic: ShardedTopic, mock_redis_client: MagicMock):
"""Test that publish() calls Redis SPUBLISH with correct parameters."""
payload = b"test sharded message"
sharded_topic.publish(payload)
mock_redis_client.spublish.assert_called_once_with("test-sharded-topic", payload)
def test_subscribe_returns_sharded_subscription(self, sharded_topic: ShardedTopic, mock_redis_client: MagicMock):
"""Test that subscribe() returns a _RedisShardedSubscription instance."""
subscription = sharded_topic.subscribe()
assert isinstance(subscription, _RedisShardedSubscription)
assert subscription._pubsub is mock_redis_client.pubsub.return_value
assert subscription._topic == "test-sharded-topic"
@dataclasses.dataclass(frozen=True)
class SubscriptionTestCase:
"""Test case data for subscription tests."""
@@ -175,14 +262,14 @@ class TestRedisSubscription:
"""Test that _start_if_needed() raises error when subscription is closed."""
subscription.close()
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
with pytest.raises(SubscriptionClosedError, match="The Redis regular subscription is closed"):
subscription._start_if_needed()
def test_start_if_needed_when_cleaned_up(self, subscription: _RedisSubscription):
"""Test that _start_if_needed() raises error when pubsub is None."""
subscription._pubsub = None
with pytest.raises(SubscriptionClosedError, match="The Redis subscription has been cleaned up"):
with pytest.raises(SubscriptionClosedError, match="The Redis regular subscription has been cleaned up"):
subscription._start_if_needed()
def test_context_manager_usage(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
@@ -250,7 +337,7 @@ class TestRedisSubscription:
"""Test that iterator raises error when subscription is closed."""
subscription.close()
with pytest.raises(BroadcastChannelError, match="The Redis subscription is closed"):
with pytest.raises(BroadcastChannelError, match="The Redis regular subscription is closed"):
iter(subscription)
# ==================== Message Enqueue Tests ====================
@@ -465,21 +552,21 @@ class TestRedisSubscription:
"""Test iterator behavior after close."""
subscription.close()
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
with pytest.raises(SubscriptionClosedError, match="The Redis regular subscription is closed"):
iter(subscription)
def test_start_after_close(self, subscription: _RedisSubscription):
"""Test start attempts after close."""
subscription.close()
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
with pytest.raises(SubscriptionClosedError, match="The Redis regular subscription is closed"):
subscription._start_if_needed()
def test_pubsub_none_operations(self, subscription: _RedisSubscription):
"""Test operations when pubsub is None."""
subscription._pubsub = None
with pytest.raises(SubscriptionClosedError, match="The Redis subscription has been cleaned up"):
with pytest.raises(SubscriptionClosedError, match="The Redis regular subscription has been cleaned up"):
subscription._start_if_needed()
# Close should still work
@@ -512,3 +599,805 @@ class TestRedisSubscription:
with pytest.raises(SubscriptionClosedError):
subscription.receive()
class TestRedisShardedSubscription:
"""Test cases for the _RedisShardedSubscription class."""
@pytest.fixture
def mock_pubsub(self) -> MagicMock:
"""Create a mock PubSub instance for testing."""
pubsub = MagicMock()
pubsub.ssubscribe = MagicMock()
pubsub.sunsubscribe = MagicMock()
pubsub.close = MagicMock()
pubsub.get_sharded_message = MagicMock()
return pubsub
@pytest.fixture
def sharded_subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisShardedSubscription, None, None]:
"""Create a _RedisShardedSubscription instance for testing."""
subscription = _RedisShardedSubscription(
pubsub=mock_pubsub,
topic="test-sharded-topic",
)
yield subscription
subscription.close()
@pytest.fixture
def started_sharded_subscription(
self, sharded_subscription: _RedisShardedSubscription
) -> _RedisShardedSubscription:
"""Create a sharded subscription that has been started."""
sharded_subscription._start_if_needed()
return sharded_subscription
# ==================== Lifecycle Tests ====================
def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock):
"""Test that sharded subscription is properly initialized."""
subscription = _RedisShardedSubscription(
pubsub=mock_pubsub,
topic="test-sharded-topic",
)
assert subscription._pubsub is mock_pubsub
assert subscription._topic == "test-sharded-topic"
assert not subscription._closed.is_set()
assert subscription._dropped_count == 0
assert subscription._listener_thread is None
assert not subscription._started
def test_start_if_needed_first_call(self, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock):
"""Test that _start_if_needed() properly starts sharded subscription on first call."""
sharded_subscription._start_if_needed()
mock_pubsub.ssubscribe.assert_called_once_with("test-sharded-topic")
assert sharded_subscription._started is True
assert sharded_subscription._listener_thread is not None
def test_start_if_needed_subsequent_calls(self, started_sharded_subscription: _RedisShardedSubscription):
"""Test that _start_if_needed() doesn't start sharded subscription on subsequent calls."""
original_thread = started_sharded_subscription._listener_thread
started_sharded_subscription._start_if_needed()
# Should not create a new listener thread
assert started_sharded_subscription._listener_thread is original_thread
def test_start_if_needed_when_closed(self, sharded_subscription: _RedisShardedSubscription):
"""Test that _start_if_needed() raises error when sharded subscription is closed."""
sharded_subscription.close()
with pytest.raises(SubscriptionClosedError, match="The Redis sharded subscription is closed"):
sharded_subscription._start_if_needed()
def test_start_if_needed_when_cleaned_up(self, sharded_subscription: _RedisShardedSubscription):
"""Test that _start_if_needed() raises error when pubsub is None."""
sharded_subscription._pubsub = None
with pytest.raises(SubscriptionClosedError, match="The Redis sharded subscription has been cleaned up"):
sharded_subscription._start_if_needed()
def test_context_manager_usage(self, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock):
"""Test that sharded subscription works as context manager."""
with sharded_subscription as sub:
assert sub is sharded_subscription
assert sharded_subscription._started is True
mock_pubsub.ssubscribe.assert_called_once_with("test-sharded-topic")
def test_close_idempotent(self, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock):
"""Test that close() is idempotent and can be called multiple times."""
sharded_subscription._start_if_needed()
# Close multiple times
sharded_subscription.close()
sharded_subscription.close()
sharded_subscription.close()
# Should only cleanup once
mock_pubsub.sunsubscribe.assert_called_once_with("test-sharded-topic")
mock_pubsub.close.assert_called_once()
assert sharded_subscription._pubsub is None
assert sharded_subscription._closed.is_set()
def test_close_cleanup(self, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock):
"""Test that close() properly cleans up all resources."""
sharded_subscription._start_if_needed()
thread = sharded_subscription._listener_thread
sharded_subscription.close()
# Verify cleanup
mock_pubsub.sunsubscribe.assert_called_once_with("test-sharded-topic")
mock_pubsub.close.assert_called_once()
assert sharded_subscription._pubsub is None
assert sharded_subscription._listener_thread is None
# Wait for thread to finish (with timeout)
if thread and thread.is_alive():
thread.join(timeout=1.0)
assert not thread.is_alive()
# ==================== Message Processing Tests ====================
def test_message_iterator_with_messages(self, started_sharded_subscription: _RedisShardedSubscription):
"""Test message iterator behavior with messages in queue."""
test_messages = [b"sharded_msg1", b"sharded_msg2", b"sharded_msg3"]
# Add messages to queue
for msg in test_messages:
started_sharded_subscription._queue.put_nowait(msg)
# Iterate through messages
iterator = iter(started_sharded_subscription)
received_messages = []
for msg in iterator:
received_messages.append(msg)
if len(received_messages) >= len(test_messages):
break
assert received_messages == test_messages
def test_message_iterator_when_closed(self, sharded_subscription: _RedisShardedSubscription):
"""Test that iterator raises error when sharded subscription is closed."""
sharded_subscription.close()
with pytest.raises(SubscriptionClosedError, match="The Redis sharded subscription is closed"):
iter(sharded_subscription)
# ==================== Message Enqueue Tests ====================
def test_enqueue_message_success(self, started_sharded_subscription: _RedisShardedSubscription):
"""Test successful message enqueue."""
payload = b"test sharded message"
started_sharded_subscription._enqueue_message(payload)
assert started_sharded_subscription._queue.qsize() == 1
assert started_sharded_subscription._queue.get_nowait() == payload
def test_enqueue_message_when_closed(self, sharded_subscription: _RedisShardedSubscription):
"""Test message enqueue when sharded subscription is closed."""
sharded_subscription.close()
payload = b"test sharded message"
# Should not raise exception, but should not enqueue
sharded_subscription._enqueue_message(payload)
assert sharded_subscription._queue.empty()
def test_enqueue_message_with_full_queue(self, started_sharded_subscription: _RedisShardedSubscription):
"""Test message enqueue with full queue (dropping behavior)."""
# Fill the queue
for i in range(started_sharded_subscription._queue.maxsize):
started_sharded_subscription._queue.put_nowait(f"old_msg_{i}".encode())
# Try to enqueue new message (should drop oldest)
new_message = b"new_sharded_message"
started_sharded_subscription._enqueue_message(new_message)
# Should have dropped one message and added new one
assert started_sharded_subscription._dropped_count == 1
# New message should be in queue
messages = []
while not started_sharded_subscription._queue.empty():
messages.append(started_sharded_subscription._queue.get_nowait())
assert new_message in messages
# ==================== Listener Thread Tests ====================
@patch("time.sleep", side_effect=lambda x: None) # Speed up test
def test_listener_thread_normal_operation(
self, mock_sleep, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock
):
"""Test sharded listener thread normal operation."""
# Mock sharded message from Redis
mock_message = {"type": "smessage", "channel": "test-sharded-topic", "data": b"test sharded payload"}
mock_pubsub.get_sharded_message.return_value = mock_message
# Start listener
sharded_subscription._start_if_needed()
# Wait a bit for processing
time.sleep(0.1)
# Verify message was processed
assert not sharded_subscription._queue.empty()
assert sharded_subscription._queue.get_nowait() == b"test sharded payload"
def test_listener_thread_ignores_subscribe_messages(
self, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock
):
"""Test that listener thread ignores ssubscribe/sunsubscribe messages."""
mock_message = {"type": "ssubscribe", "channel": "test-sharded-topic", "data": 1}
mock_pubsub.get_sharded_message.return_value = mock_message
sharded_subscription._start_if_needed()
time.sleep(0.1)
# Should not enqueue ssubscribe messages
assert sharded_subscription._queue.empty()
def test_listener_thread_ignores_wrong_channel(
self, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock
):
"""Test that listener thread ignores messages from wrong channels."""
mock_message = {"type": "smessage", "channel": "wrong-sharded-topic", "data": b"test payload"}
mock_pubsub.get_sharded_message.return_value = mock_message
sharded_subscription._start_if_needed()
time.sleep(0.1)
# Should not enqueue messages from wrong channels
assert sharded_subscription._queue.empty()
def test_listener_thread_ignores_regular_messages(
self, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock
):
"""Test that listener thread ignores regular (non-sharded) messages."""
mock_message = {"type": "message", "channel": "test-sharded-topic", "data": b"test payload"}
mock_pubsub.get_sharded_message.return_value = mock_message
sharded_subscription._start_if_needed()
time.sleep(0.1)
# Should not enqueue regular messages in sharded subscription
assert sharded_subscription._queue.empty()
def test_listener_thread_handles_redis_exceptions(
self, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock
):
"""Test that listener thread handles Redis exceptions gracefully."""
mock_pubsub.get_sharded_message.side_effect = Exception("Redis error")
sharded_subscription._start_if_needed()
# Wait for thread to handle exception
time.sleep(0.2)
# The unhandled exception from get_sharded_message terminates the listener thread
assert sharded_subscription._listener_thread is not None
assert not sharded_subscription._listener_thread.is_alive()
def test_listener_thread_stops_when_closed(
self, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock
):
"""Test that listener thread stops when sharded subscription is closed."""
sharded_subscription._start_if_needed()
thread = sharded_subscription._listener_thread
# Close subscription
sharded_subscription.close()
# Wait for thread to finish
if thread is not None and thread.is_alive():
thread.join(timeout=1.0)
assert thread is None or not thread.is_alive()
# ==================== Table-driven Tests ====================
@pytest.mark.parametrize(
"test_case",
[
SubscriptionTestCase(
name="basic_sharded_message",
buffer_size=5,
payload=b"hello sharded world",
expected_messages=[b"hello sharded world"],
description="Basic sharded message publishing and receiving",
),
SubscriptionTestCase(
name="empty_sharded_message",
buffer_size=5,
payload=b"",
expected_messages=[b""],
description="Empty sharded message handling",
),
SubscriptionTestCase(
name="large_sharded_message",
buffer_size=5,
payload=b"x" * 10000,
expected_messages=[b"x" * 10000],
description="Large sharded message handling",
),
SubscriptionTestCase(
name="unicode_sharded_message",
buffer_size=5,
payload="你好世界".encode(),
expected_messages=["你好世界".encode()],
description="Unicode sharded message handling",
),
],
)
def test_sharded_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock):
"""Test various sharded subscription scenarios using table-driven approach."""
subscription = _RedisShardedSubscription(
pubsub=mock_pubsub,
topic="test-sharded-topic",
)
# Simulate receiving sharded message
mock_message = {"type": "smessage", "channel": "test-sharded-topic", "data": test_case.payload}
mock_pubsub.get_sharded_message.return_value = mock_message
try:
with subscription:
# Wait for message processing
time.sleep(0.1)
# Collect received messages
received = []
for msg in subscription:
received.append(msg)
if len(received) >= len(test_case.expected_messages):
break
assert received == test_case.expected_messages, f"Failed: {test_case.description}"
finally:
subscription.close()
def test_concurrent_close_and_enqueue(self, started_sharded_subscription: _RedisShardedSubscription):
"""Test concurrent close and enqueue operations for sharded subscription."""
errors = []
def close_subscription():
try:
time.sleep(0.05) # Small delay
started_sharded_subscription.close()
except Exception as e:
errors.append(e)
def enqueue_messages():
try:
for i in range(50):
started_sharded_subscription._enqueue_message(f"sharded_msg_{i}".encode())
time.sleep(0.001)
except Exception as e:
errors.append(e)
# Start threads
close_thread = threading.Thread(target=close_subscription)
enqueue_thread = threading.Thread(target=enqueue_messages)
close_thread.start()
enqueue_thread.start()
# Wait for completion
close_thread.join(timeout=2.0)
enqueue_thread.join(timeout=2.0)
# Should not have any errors (operations should be safe)
assert len(errors) == 0
# ==================== Error Handling Tests ====================
def test_iterator_after_close(self, sharded_subscription: _RedisShardedSubscription):
"""Test iterator behavior after close for sharded subscription."""
sharded_subscription.close()
with pytest.raises(SubscriptionClosedError, match="The Redis sharded subscription is closed"):
iter(sharded_subscription)
def test_start_after_close(self, sharded_subscription: _RedisShardedSubscription):
"""Test start attempts after close for sharded subscription."""
sharded_subscription.close()
with pytest.raises(SubscriptionClosedError, match="The Redis sharded subscription is closed"):
sharded_subscription._start_if_needed()
def test_pubsub_none_operations(self, sharded_subscription: _RedisShardedSubscription):
"""Test operations when pubsub is None for sharded subscription."""
sharded_subscription._pubsub = None
with pytest.raises(SubscriptionClosedError, match="The Redis sharded subscription has been cleaned up"):
sharded_subscription._start_if_needed()
# Close should still work
sharded_subscription.close() # Should not raise
def test_channel_name_variations(self, mock_pubsub: MagicMock):
"""Test various sharded channel name formats."""
channel_names = [
"simple",
"with-dashes",
"with_underscores",
"with.numbers",
"WITH.UPPERCASE",
"mixed-CASE_name",
"very.long.sharded.channel.name.with.multiple.parts",
]
for channel_name in channel_names:
subscription = _RedisShardedSubscription(
pubsub=mock_pubsub,
topic=channel_name,
)
subscription._start_if_needed()
mock_pubsub.ssubscribe.assert_called_with(channel_name)
subscription.close()
def test_receive_on_closed_sharded_subscription(self, sharded_subscription: _RedisShardedSubscription):
"""Test receive method on closed sharded subscription."""
sharded_subscription.close()
with pytest.raises(SubscriptionClosedError):
sharded_subscription.receive()
def test_receive_with_timeout(self, started_sharded_subscription: _RedisShardedSubscription):
"""Test receive method with timeout for sharded subscription."""
# Should return None when no message available and timeout expires
result = started_sharded_subscription.receive(timeout=0.01)
assert result is None
def test_receive_with_message(self, started_sharded_subscription: _RedisShardedSubscription):
"""Test receive method when message is available for sharded subscription."""
test_message = b"test sharded receive"
started_sharded_subscription._queue.put_nowait(test_message)
result = started_sharded_subscription.receive(timeout=1.0)
assert result == test_message
class TestRedisSubscriptionCommon:
"""Parameterized tests for common Redis subscription functionality.
This test suite eliminates duplication by running the same tests against
both regular and sharded subscriptions using pytest.mark.parametrize.
"""
@pytest.fixture(
params=[
("regular", _RedisSubscription),
("sharded", _RedisShardedSubscription),
]
)
def subscription_params(self, request):
"""Parameterized fixture providing subscription type and class."""
return request.param
@pytest.fixture
def mock_pubsub(self) -> MagicMock:
"""Create a mock PubSub instance for testing."""
pubsub = MagicMock()
# Set up mock methods for both regular and sharded subscriptions
pubsub.subscribe = MagicMock()
pubsub.unsubscribe = MagicMock()
pubsub.ssubscribe = MagicMock() # type: ignore[attr-defined]
pubsub.sunsubscribe = MagicMock() # type: ignore[attr-defined]
pubsub.get_message = MagicMock()
pubsub.get_sharded_message = MagicMock() # type: ignore[attr-defined]
pubsub.close = MagicMock()
return pubsub
@pytest.fixture
def subscription(self, subscription_params, mock_pubsub: MagicMock):
"""Create a subscription instance based on parameterized type."""
subscription_type, subscription_class = subscription_params
topic_name = f"test-{subscription_type}-topic"
subscription = subscription_class(
pubsub=mock_pubsub,
topic=topic_name,
)
yield subscription
subscription.close()
@pytest.fixture
def started_subscription(self, subscription):
"""Create a subscription that has been started."""
subscription._start_if_needed()
return subscription
# ==================== Initialization Tests ====================
def test_subscription_initialization(self, subscription, subscription_params):
"""Test that subscription is properly initialized."""
subscription_type, _ = subscription_params
expected_topic = f"test-{subscription_type}-topic"
assert subscription._pubsub is not None
assert subscription._topic == expected_topic
assert not subscription._closed.is_set()
assert subscription._dropped_count == 0
assert subscription._listener_thread is None
assert not subscription._started
def test_subscription_type(self, subscription, subscription_params):
"""Test that subscription returns correct type."""
subscription_type, _ = subscription_params
assert subscription._get_subscription_type() == subscription_type
# ==================== Lifecycle Tests ====================
def test_start_if_needed_first_call(self, subscription, subscription_params, mock_pubsub: MagicMock):
"""Test that _start_if_needed() properly starts subscription on first call."""
subscription_type, _ = subscription_params
subscription._start_if_needed()
if subscription_type == "regular":
mock_pubsub.subscribe.assert_called_once()
else:
mock_pubsub.ssubscribe.assert_called_once()
assert subscription._started is True
assert subscription._listener_thread is not None
def test_start_if_needed_subsequent_calls(self, started_subscription):
"""Test that _start_if_needed() doesn't start subscription on subsequent calls."""
original_thread = started_subscription._listener_thread
started_subscription._start_if_needed()
# Should not create new thread
assert started_subscription._listener_thread is original_thread
def test_context_manager_usage(self, subscription, subscription_params, mock_pubsub: MagicMock):
"""Test that subscription works as context manager."""
subscription_type, _ = subscription_params
expected_topic = f"test-{subscription_type}-topic"
with subscription as sub:
assert sub is subscription
assert subscription._started is True
if subscription_type == "regular":
mock_pubsub.subscribe.assert_called_with(expected_topic)
else:
mock_pubsub.ssubscribe.assert_called_with(expected_topic)
def test_close_idempotent(self, subscription, subscription_params, mock_pubsub: MagicMock):
"""Test that close() is idempotent and can be called multiple times."""
subscription_type, _ = subscription_params
subscription._start_if_needed()
# Close multiple times
subscription.close()
subscription.close()
subscription.close()
# Should only cleanup once
if subscription_type == "regular":
mock_pubsub.unsubscribe.assert_called_once()
else:
mock_pubsub.sunsubscribe.assert_called_once()
mock_pubsub.close.assert_called_once()
assert subscription._pubsub is None
assert subscription._closed.is_set()
# ==================== Message Processing Tests ====================
def test_message_iterator_with_messages(self, started_subscription):
"""Test message iterator behavior with messages in queue."""
test_messages = [b"msg1", b"msg2", b"msg3"]
# Add messages to queue
for msg in test_messages:
started_subscription._queue.put_nowait(msg)
# Iterate through messages
iterator = iter(started_subscription)
received_messages = []
for msg in iterator:
received_messages.append(msg)
if len(received_messages) >= len(test_messages):
break
assert received_messages == test_messages
def test_message_iterator_when_closed(self, subscription, subscription_params):
"""Test that iterator raises error when subscription is closed."""
subscription_type, _ = subscription_params
subscription.close()
with pytest.raises(SubscriptionClosedError, match=f"The Redis {subscription_type} subscription is closed"):
iter(subscription)
# ==================== Message Enqueue Tests ====================
def test_enqueue_message_success(self, started_subscription):
"""Test successful message enqueue."""
payload = b"test message"
started_subscription._enqueue_message(payload)
assert started_subscription._queue.qsize() == 1
assert started_subscription._queue.get_nowait() == payload
def test_enqueue_message_when_closed(self, subscription):
"""Test message enqueue when subscription is closed."""
subscription.close()
payload = b"test message"
# Should not raise exception, but should not enqueue
subscription._enqueue_message(payload)
assert subscription._queue.empty()
def test_enqueue_message_with_full_queue(self, started_subscription):
"""Test message enqueue with full queue (dropping behavior)."""
# Fill the queue
for i in range(started_subscription._queue.maxsize):
started_subscription._queue.put_nowait(f"old_msg_{i}".encode())
# Try to enqueue new message (should drop oldest)
new_message = b"new_message"
started_subscription._enqueue_message(new_message)
# Should have dropped one message and added new one
assert started_subscription._dropped_count == 1
# New message should be in queue
messages = []
while not started_subscription._queue.empty():
messages.append(started_subscription._queue.get_nowait())
assert new_message in messages
# ==================== Message Type Tests ====================
def test_get_message_type(self, subscription, subscription_params):
"""Test that subscription returns correct message type."""
subscription_type, _ = subscription_params
expected_type = "message" if subscription_type == "regular" else "smessage"
assert subscription._get_message_type() == expected_type
# ==================== Error Handling Tests ====================
def test_start_if_needed_when_closed(self, subscription, subscription_params):
"""Test that _start_if_needed() raises error when subscription is closed."""
subscription_type, _ = subscription_params
subscription.close()
with pytest.raises(SubscriptionClosedError, match=f"The Redis {subscription_type} subscription is closed"):
subscription._start_if_needed()
def test_start_if_needed_when_cleaned_up(self, subscription, subscription_params):
"""Test that _start_if_needed() raises error when pubsub is None."""
subscription_type, _ = subscription_params
subscription._pubsub = None
with pytest.raises(
SubscriptionClosedError, match=f"The Redis {subscription_type} subscription has been cleaned up"
):
subscription._start_if_needed()
def test_iterator_after_close(self, subscription, subscription_params):
"""Test iterator behavior after close."""
subscription_type, _ = subscription_params
subscription.close()
with pytest.raises(SubscriptionClosedError, match=f"The Redis {subscription_type} subscription is closed"):
iter(subscription)
def test_start_after_close(self, subscription, subscription_params):
"""Test start attempts after close."""
subscription_type, _ = subscription_params
subscription.close()
with pytest.raises(SubscriptionClosedError, match=f"The Redis {subscription_type} subscription is closed"):
subscription._start_if_needed()
def test_pubsub_none_operations(self, subscription, subscription_params):
"""Test operations when pubsub is None."""
subscription_type, _ = subscription_params
subscription._pubsub = None
with pytest.raises(
SubscriptionClosedError, match=f"The Redis {subscription_type} subscription has been cleaned up"
):
subscription._start_if_needed()
# Close should still work
subscription.close() # Should not raise
def test_receive_on_closed_subscription(self, subscription, subscription_params):
"""Test receive method on closed subscription."""
subscription.close()
with pytest.raises(SubscriptionClosedError):
subscription.receive()
# ==================== Table-driven Tests ====================
@pytest.mark.parametrize(
"test_case",
[
SubscriptionTestCase(
name="basic_message",
buffer_size=5,
payload=b"hello world",
expected_messages=[b"hello world"],
description="Basic message publishing and receiving",
),
SubscriptionTestCase(
name="empty_message",
buffer_size=5,
payload=b"",
expected_messages=[b""],
description="Empty message handling",
),
SubscriptionTestCase(
name="large_message",
buffer_size=5,
payload=b"x" * 10000,
expected_messages=[b"x" * 10000],
description="Large message handling",
),
SubscriptionTestCase(
name="unicode_message",
buffer_size=5,
payload="你好世界".encode(),
expected_messages=["你好世界".encode()],
description="Unicode message handling",
),
],
)
def test_subscription_scenarios(
self, test_case: SubscriptionTestCase, subscription, subscription_params, mock_pubsub: MagicMock
):
"""Test various subscription scenarios using table-driven approach."""
subscription_type, _ = subscription_params
expected_topic = f"test-{subscription_type}-topic"
expected_message_type = "message" if subscription_type == "regular" else "smessage"
# Simulate receiving message
mock_message = {"type": expected_message_type, "channel": expected_topic, "data": test_case.payload}
if subscription_type == "regular":
mock_pubsub.get_message.return_value = mock_message
else:
mock_pubsub.get_sharded_message.return_value = mock_message
try:
with subscription:
# Wait for message processing
time.sleep(0.1)
# Collect received messages
received = []
for msg in subscription:
received.append(msg)
if len(received) >= len(test_case.expected_messages):
break
assert received == test_case.expected_messages, f"Failed: {test_case.description}"
finally:
subscription.close()
# ==================== Concurrency Tests ====================
def test_concurrent_close_and_enqueue(self, started_subscription):
"""Test concurrent close and enqueue operations."""
errors = []
def close_subscription():
try:
time.sleep(0.05) # Small delay
started_subscription.close()
except Exception as e:
errors.append(e)
def enqueue_messages():
try:
for i in range(50):
started_subscription._enqueue_message(f"msg_{i}".encode())
time.sleep(0.001)
except Exception as e:
errors.append(e)
# Start threads
close_thread = threading.Thread(target=close_subscription)
enqueue_thread = threading.Thread(target=enqueue_messages)
close_thread.start()
enqueue_thread.start()
# Wait for completion
close_thread.join(timeout=2.0)
enqueue_thread.join(timeout=2.0)
# Should not have any errors (operations should be safe)
assert len(errors) == 0