diff --git a/src/cryptocom/exchange/api.py b/src/cryptocom/exchange/api.py index ac1cd9e..8bb339d 100644 --- a/src/cryptocom/exchange/api.py +++ b/src/cryptocom/exchange/api.py @@ -7,6 +7,7 @@ import asyncio import hashlib from urllib.parse import urljoin +from .rate_limiter import RateLimiter import aiohttp @@ -24,16 +25,33 @@ class ApiProvider: auth_required=True, timeout=25, retries=6, root_url='https://api.crypto.com/v2/', ws_root_url='wss://stream.crypto.com/v2/', logger=None): + self.api_key = api_key self.api_secret = api_secret self.root_url = root_url self.ws_root_url = ws_root_url self.timeout = timeout self.retries = retries + self.limits = { + # method: (req_limit, period) + 'private/create-order': (15, 0.1), + 'private/cancel-order': (15, 0.1), + 'private/cancel-all-orders': (15, 0.1), + 'private/margin/create-order': (15, 0.1), + 'private/margin/cancel-order': (15, 0.1), + 'private/margin/cancel-all-orders': (15, 0.1), + + 'private/get-order-detail': (30, 0.1), + 'private/margin/get-order-detail': (30, 0.1), + + 'private/get-trades': (1, 1), + 'private/margin/get-trades': (1, 1), + 'private/get-order-history': (1, 1), + 'private/margin/get-order-history': (1, 1) + } # NOTE: do not change this, due to crypto.com rate-limits # TODO: add more strict settings, req/per second or milliseconds - self.semaphore = asyncio.Semaphore(20) if not auth_required: return @@ -79,22 +97,37 @@ class ApiProvider: async def request(self, method, path, params=None, data=None, sign=False): original_data = data timeout = aiohttp.ClientTimeout(total=self.timeout) + request_type = path.split('/')[0] + + if not (path in self.limits.keys()) and request_type == 'public': + rate_limit, period = 100, 1 + elif not (path in self.limits.keys()) and request_type == 'private': + rate_limit, period = 3, 0.1 + elif not (path in self.limits.keys()): + raise ApiError(f'Wrong path: {path}') + else: + rate_limit, period = self.limits[path] + + rate_limiter = RateLimiter(rate_limit=rate_limit, period=period, + concurrency_limit=1) + for count in range(self.retries + 1): if sign: data = self._sign(path, original_data) try: - async with aiohttp.ClientSession(timeout=timeout) as session: - async with self.semaphore: - resp = await session.request( - method, urljoin(self.root_url, path), - params=params, json=data, - headers={'content-type': 'application/json'} - ) - resp_json = await resp.json() - if resp.status != 200: - raise ApiError( - f"Error: {resp_json}. " - f"Status: {resp.status}. Json params: {data}") + async with rate_limiter: + async with aiohttp.ClientSession(timeout=timeout) as session: + async with rate_limiter.throttle(): + resp = await session.request( + method, urljoin(self.root_url, path), + params=params, json=data, + headers={'content-type': 'application/json'} + ) + resp_json = await resp.json() + if resp.status != 200: + raise ApiError( + f"Error: {resp_json}. " + f"Status: {resp.status}. Json params: {data}") except aiohttp.ClientConnectorError: raise ApiError(f"Cannot connect to host {self.root_url}") except asyncio.TimeoutError: diff --git a/src/cryptocom/exchange/rate_limiter.py b/src/cryptocom/exchange/rate_limiter.py new file mode 100644 index 0000000..dc58df6 --- /dev/null +++ b/src/cryptocom/exchange/rate_limiter.py @@ -0,0 +1,98 @@ +import asyncio +import math +import time +import traceback + +from contextlib import asynccontextmanager + + +class RateLimiter: + def __init__(self, + rate_limit: int, + period: float or int, # takes seconds + concurrency_limit: int) -> None: + if not rate_limit or rate_limit < 1: + raise ValueError('rate limit must be non zero positive number') + if not concurrency_limit or concurrency_limit < 1: + raise ValueError('concurrent limit must be non zero positive number') + + self.rate_limit = rate_limit + self.period = period + self.tokens_queue = asyncio.Queue(rate_limit) + self.tokens_consumer_task = asyncio.create_task(self.consume_tokens()) + self.semaphore = asyncio.Semaphore(concurrency_limit) + + async def add_token(self) -> None: + await self.tokens_queue.put(1) + return None + + async def consume_tokens(self): + try: + consumption_rate = self.period / self.rate_limit + last_consumption_time = 0 + + while True: + if self.tokens_queue.empty(): + await asyncio.sleep(consumption_rate) + continue + + current_consumption_time = time.monotonic() + total_tokens = self.tokens_queue.qsize() + tokens_to_consume = self.get_tokens_amount_to_consume( + consumption_rate, + current_consumption_time, + last_consumption_time, + total_tokens + ) + + for _ in range(0, tokens_to_consume): + self.tokens_queue.get_nowait() + + last_consumption_time = time.monotonic() + + await asyncio.sleep(consumption_rate) + except asyncio.CancelledError: + raise + except Exception as e: + raise + + @staticmethod + def get_tokens_amount_to_consume(consumption_rate, current_consumption_time, + last_consumption_time, total_tokens): + time_from_last_consumption = current_consumption_time - last_consumption_time + calculated_tokens_to_consume = math.floor(time_from_last_consumption / consumption_rate) + tokens_to_consume = min(total_tokens, calculated_tokens_to_consume) + + return tokens_to_consume + + @asynccontextmanager + async def throttle(self): + await self.semaphore.acquire() + await self.add_token() + try: + yield + finally: + self.semaphore.release() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if exc_type: + pass + # print(traceback.format_exc()) + + await self.close() + + async def close(self) -> None: + if self.tokens_consumer_task and not self.tokens_consumer_task.cancelled(): + try: + self.tokens_consumer_task.cancel() + await self.tokens_consumer_task + except asyncio.CancelledError: + # print(traceback.format_exc()) + pass + + except Exception as e: + # print(traceback.format_exc()) + raise diff --git a/tests/test_api.py b/tests/test_api.py index 541cc4e..aff2ce8 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -55,3 +55,18 @@ async def test_wrong_api_response(): api = cro.ApiProvider(auth_required=False) with pytest.raises(cro.ApiError): await api.post('account') + + +# @pytest.mark.asyncio +# async def test_api_rate_limits(): +# api = cro.ApiProvider(from_env=True) +# account = cro.Account(from_env=True) + +# for _ in range(0, 100): +# await account.get_balance() + +# for _ in range(0, 100): +# await account.get_orders_history(cro.pairs.CRO_USDT, page_size=50) + +# for _ in range(0, 100): +# await api.get('public/get-ticker')