Skip to content

Commit b0efe55

Browse files
committed
Check the state when sending
Raising a LocalProtocolError if the state does not allow for the event to be sent.
1 parent 84b4f32 commit b0efe55

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

src/wsproto/connection.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -88,24 +88,25 @@ def state(self) -> ConnectionState:
8888

8989
def send(self, event: Event) -> bytes:
9090
data = b""
91-
if isinstance(event, Message):
91+
if isinstance(event, Message) and self.state == ConnectionState.OPEN:
9292
data += self._proto.send_data(event.data, event.message_finished)
93-
elif isinstance(event, Ping):
93+
elif isinstance(event, Ping) and self.state == ConnectionState.OPEN:
9494
data += self._proto.ping(event.payload)
95-
elif isinstance(event, Pong):
95+
elif isinstance(event, Pong) and self.state == ConnectionState.OPEN:
9696
data += self._proto.pong(event.payload)
97-
elif isinstance(event, CloseConnection):
98-
if self.state not in {ConnectionState.OPEN, ConnectionState.REMOTE_CLOSING}:
99-
raise LocalProtocolError(
100-
"Connection cannot be closed in state %s" % self.state
101-
)
97+
elif isinstance(event, CloseConnection) and self.state in {
98+
ConnectionState.OPEN,
99+
ConnectionState.REMOTE_CLOSING,
100+
}:
102101
data += self._proto.close(event.code, event.reason)
103102
if self.state == ConnectionState.REMOTE_CLOSING:
104103
self._state = ConnectionState.CLOSED
105104
else:
106105
self._state = ConnectionState.LOCAL_CLOSING
107106
else:
108-
raise LocalProtocolError(f"Event {event} cannot be sent.")
107+
raise LocalProtocolError(
108+
f"Event {event} cannot be sent in state {self.state}."
109+
)
109110
return data
110111

111112
def receive_data(self, data: Optional[bytes]) -> None:

test/test_connection.py

+7
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ def test_close_whilst_closing() -> None:
8888
client.send(CloseConnection(code=CloseReason.NORMAL_CLOSURE))
8989

9090

91+
def test_send_after_close() -> None:
92+
client = Connection(CLIENT)
93+
client.send(CloseConnection(code=CloseReason.NORMAL_CLOSURE))
94+
with pytest.raises(LocalProtocolError):
95+
client.send(TextMessage(data="", message_finished=True))
96+
97+
9198
@pytest.mark.parametrize("client_sends", [True, False])
9299
def test_ping_pong(client_sends: bool) -> None:
93100
client = Connection(CLIENT)

0 commit comments

Comments
 (0)