diff --git a/pythclient/price_feeds.py b/pythclient/price_feeds.py index 498e80d..2dc337f 100644 --- a/pythclient/price_feeds.py +++ b/pythclient/price_feeds.py @@ -1,8 +1,7 @@ import base64 import binascii -import struct from struct import unpack -from typing import Any, Dict, List, Optional +from typing import List, Literal, Optional, Union, cast from Crypto.Hash import keccak from loguru import logger @@ -164,7 +163,7 @@ def __str__(self): # Referenced from https://github.com/pyth-network/pyth-crosschain/blob/110caed6be3be7885773d2f6070b143cc13fb0ee/price_service/server/src/encoding.ts#L24 -def encode_vaa_for_chain(vaa, vaa_format, buffer=False): +def encode_vaa_for_chain(vaa: str, vaa_format: str, buffer=False) -> Union[bytes, str]: # check if vaa is already in vaa_format if isinstance(vaa, str): if vaa_format == DEFAULT_VAA_ENCODING: @@ -197,7 +196,7 @@ def encode_vaa_for_chain(vaa, vaa_format, buffer=False): # Referenced from https://github.com/wormhole-foundation/wormhole/blob/main/sdk/js/src/vaa/wormhole.ts#L26-L56 def parse_vaa(vaa, encoding): - vaa = encode_vaa_for_chain(vaa, encoding, buffer=True) + vaa = cast(bytes, encode_vaa_for_chain(vaa, encoding, buffer=True)) num_signers = vaa[5] sig_length = 66 @@ -284,7 +283,7 @@ def parse_batch_price_attestation(bytes_): offset += 2 price_attestations = [] - for i in range(batch_len): + for _ in range(batch_len): price_attestations.append( parse_price_attestation(bytes_[offset : offset + attestation_size]) ) @@ -401,13 +400,13 @@ def is_accumulator_update(vaa, encoding=DEFAULT_VAA_ENCODING) -> bool: Returns: bool: True if the VAA is an accumulator update, False otherwise. """ - if encode_vaa_for_chain(vaa, encoding, buffer=True)[:4].hex() == ACCUMULATOR_MAGIC: + if cast(bytes, encode_vaa_for_chain(vaa, encoding, buffer=True))[:4].hex() == ACCUMULATOR_MAGIC: return True return False # Referenced from https://github.com/pyth-network/pyth-crosschain/blob/110caed6be3be7885773d2f6070b143cc13fb0ee/price_service/server/src/rest.ts#L139 -def vaa_to_price_infos(vaa, encoding=DEFAULT_VAA_ENCODING) -> List[PriceInfo]: +def vaa_to_price_infos(vaa, encoding: Literal["hex", "base64"] = DEFAULT_VAA_ENCODING) -> Optional[List[PriceInfo]]: if is_accumulator_update(vaa, encoding): return extract_price_info_from_accumulator_update(vaa, encoding) parsed_vaa = parse_vaa(vaa, encoding) @@ -425,7 +424,7 @@ def vaa_to_price_infos(vaa, encoding=DEFAULT_VAA_ENCODING) -> List[PriceInfo]: return price_infos -def vaa_to_price_info(id, vaa, encoding=DEFAULT_VAA_ENCODING) -> Optional[PriceInfo]: +def vaa_to_price_info(id: str, vaa: str, encoding: Literal["hex", "base64"] = DEFAULT_VAA_ENCODING) -> Optional[PriceInfo]: """ This function retrieves a specific PriceInfo object from a given VAA. @@ -502,14 +501,21 @@ def price_attestation_to_price_feed(price_attestation): # Referenced from https://github.com/pyth-network/pyth-crosschain/blob/1a00598334e52fc5faf967eb1170d7fc23ad828b/price_service/server/src/rest.ts#L137 def extract_price_info_from_accumulator_update( - update_data, encoding -) -> Optional[Dict[str, Any]]: + update_data: str, + encoding: Literal["hex", "base64"] +) -> Optional[List[PriceInfo]]: parsed_update_data = parse_accumulator_update(update_data, encoding) + if parsed_update_data is None: + return None + vaa_buffer = parsed_update_data.vaa if encoding == "hex": vaa_str = vaa_buffer.hex() elif encoding == "base64": vaa_str = base64.b64encode(vaa_buffer).decode("ascii") + else: + raise ValueError(f"Invalid encoding: {encoding}") + parsed_vaa = parse_vaa(vaa_str, encoding) price_infos = [] for update in parsed_update_data.updates: @@ -581,7 +587,6 @@ def extract_price_info_from_accumulator_update( return price_infos - def compress_accumulator_update(update_data_list, encoding) -> List[str]: """ This function compresses a list of accumulator update data by combining those with the same VAA. @@ -593,17 +598,21 @@ def compress_accumulator_update(update_data_list, encoding) -> List[str]: Returns: List[str]: A list of serialized accumulator update data. Each item in the list is a hexadecimal string representing - an accumulator update data. The updates with the same VAA are combined and split into chunks of 255 updates each. + an accumulator update data. The updates with the same VAA payload are combined and split into chunks of 255 updates each. """ parsed_data_dict = {} # Use a dictionary for O(1) lookup # Combine the ones with the same VAA to a list for update_data in update_data_list: parsed_update_data = parse_accumulator_update(update_data, encoding) - vaa = parsed_update_data.vaa - if vaa not in parsed_data_dict: - parsed_data_dict[vaa] = [] - parsed_data_dict[vaa].append(parsed_update_data) + if parsed_update_data is None: + raise ValueError(f"Invalid accumulator update data: {update_data}") + + payload = parse_vaa(parsed_update_data.vaa.hex(), "hex")["payload"] + + if payload not in parsed_data_dict: + parsed_data_dict[payload] = [] + parsed_data_dict[payload].append(parsed_update_data) parsed_data_list = list(parsed_data_dict.values()) # Combines accumulator update data with the same VAA into a single dictionary @@ -698,7 +707,7 @@ def serialize_accumulator_update(data, encoding): return base64.b64encode(serialized_data).decode("ascii") -def parse_accumulator_update(update_data, encoding): +def parse_accumulator_update(update_data: str, encoding: str) -> Optional[AccumulatorUpdate]: """ This function parses an accumulator update data. @@ -724,7 +733,8 @@ def parse_accumulator_update(update_data, encoding): If the update type is not 0, the function logs an info message and returns None. """ - encoded_update_data = encode_vaa_for_chain(update_data, encoding, buffer=True) + encoded_update_data = cast(bytes, encode_vaa_for_chain(update_data, encoding, buffer=True)) + offset = 0 magic = encoded_update_data[offset : offset + 4] offset += 4 diff --git a/setup.py b/setup.py index 7a33f0e..fc36bff 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setup( name='pythclient', - version='0.1.14', + version='0.1.15', packages=['pythclient'], author='Pyth Developers', author_email='contact@pyth.network',