Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add pydantic support to rtc.VideoFrame & rtc.AudioFrame #348

Merged
merged 8 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/check-types.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
run: python -m pip install --upgrade mypy

- name: Install packages
run: python -m pip install pytest ./livekit-api ./livekit-protocol ./livekit-rtc
run: python -m pip install pytest ./livekit-api ./livekit-protocol ./livekit-rtc pydantic

- name: Check Types
run: python -m mypy --install-type --non-interactive -p 'livekit-protocol' -p 'livekit-api' -p 'livekit-rtc'
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ jobs:
- name: Run tests
run: |
python3 ./livekit-rtc/rust-sdks/download_ffi.py --output livekit-rtc/livekit/rtc/resources
pip3 install pytest ./livekit-protocol ./livekit-api ./livekit-rtc
pip3 install pytest ./livekit-protocol ./livekit-api ./livekit-rtc pydantic
pytest . --ignore=livekit-rtc/rust-sdks
61 changes: 60 additions & 1 deletion livekit-rtc/livekit/rtc/audio_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ._proto import audio_frame_pb2 as proto_audio
from ._proto import ffi_pb2 as proto_ffi
from ._utils import get_address
from typing import Union
from typing import Any, Union


class AudioFrame:
Expand Down Expand Up @@ -55,6 +55,10 @@ def __init__(
"data length must be >= num_channels * samples_per_channel * sizeof(int16)"
)

if len(data) % ctypes.sizeof(ctypes.c_int16) != 0:
# can happen if data is bigger than needed
raise ValueError("data length must be a multiple of sizeof(int16)")

self._data = bytearray(data)
self._sample_rate = sample_rate
self._num_channels = num_channels
Expand Down Expand Up @@ -197,3 +201,58 @@ def __repr__(self) -> str:
f"samples_per_channel={self.samples_per_channel}, "
f"duration={self.duration:.3f})"
)

@classmethod
def __get_pydantic_core_schema__(cls, *_: Any):
from pydantic_core import core_schema
import base64

def validate_audio_frame(value: Any) -> "AudioFrame":
if isinstance(value, AudioFrame):
return value

if isinstance(value, tuple):
value = value[0]

if isinstance(value, dict):
return AudioFrame(
data=base64.b64decode(value["data"]),
sample_rate=value["sample_rate"],
num_channels=value["num_channels"],
samples_per_channel=value["samples_per_channel"],
)

raise TypeError("Invalid type for AudioFrame")

return core_schema.json_or_python_schema(
json_schema=core_schema.chain_schema(
[
core_schema.model_fields_schema(
{
"data": core_schema.model_field(core_schema.str_schema()),
"sample_rate": core_schema.model_field(
core_schema.int_schema()
),
"num_channels": core_schema.model_field(
core_schema.int_schema()
),
"samples_per_channel": core_schema.model_field(
core_schema.int_schema()
),
},
),
core_schema.no_info_plain_validator_function(validate_audio_frame),
]
),
python_schema=core_schema.no_info_plain_validator_function(
validate_audio_frame
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: {
"data": base64.b64encode(instance.data).decode("utf-8"),
"sample_rate": instance.sample_rate,
"num_channels": instance.num_channels,
"samples_per_channel": instance.samples_per_channel,
}
),
)
51 changes: 51 additions & 0 deletions livekit-rtc/livekit/rtc/video_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from ._ffi_client import FfiClient, FfiHandle
from ._utils import get_address

from typing import Any


class VideoFrame:
"""
Expand Down Expand Up @@ -203,6 +205,55 @@ def convert(
def __repr__(self) -> str:
return f"rtc.VideoFrame(width={self.width}, height={self.height}, type={self.type})"

@classmethod
def __get_pydantic_core_schema__(cls, *_: Any):
from pydantic_core import core_schema
import base64

def validate_video_frame(value: Any) -> "VideoFrame":
if isinstance(value, VideoFrame):
return value

if isinstance(value, tuple):
value = value[0]

if isinstance(value, dict):
return VideoFrame(
width=value["width"],
height=value["height"],
type=proto_video.VideoBufferType.ValueType(value["type"]),
data=base64.b64decode(value["data"]),
)

raise TypeError("Invalid type for VideoFrame")

return core_schema.json_or_python_schema(
json_schema=core_schema.chain_schema(
[
core_schema.model_fields_schema(
{
"width": core_schema.model_field(core_schema.int_schema()),
"height": core_schema.model_field(core_schema.int_schema()),
"type": core_schema.model_field(core_schema.int_schema()),
"data": core_schema.model_field(core_schema.str_schema()),
},
),
core_schema.no_info_plain_validator_function(validate_video_frame),
]
),
python_schema=core_schema.no_info_plain_validator_function(
validate_video_frame
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: {
"width": instance.width,
"height": instance.height,
"type": instance.type,
"data": base64.b64encode(instance.data).decode("utf-8"),
}
),
)


def _component_info(
data_ptr: int, stride: int, size: int
Expand Down
Loading