326 lines
9.8 KiB
Python
326 lines
9.8 KiB
Python
"""
|
|
Advanced features for VNDB Telegram Bot
|
|
Includes pagination, caching, and rate limiting
|
|
"""
|
|
import asyncio
|
|
import time
|
|
from typing import Dict, List, Any, Optional, Callable
|
|
from collections import defaultdict
|
|
from functools import wraps
|
|
import logging
|
|
|
|
# Module-level logger used by the rate limiter, cache and session helpers below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RateLimiter:
    """Sliding-window rate limiter for API requests.

    Keeps the ``time.time()`` timestamps of requests granted within the last
    ``window_seconds`` and grants at most ``max_requests`` of them.
    """

    def __init__(self, max_requests: int = 200, window_seconds: int = 300):
        """
        Initialize rate limiter

        Args:
            max_requests: Maximum requests allowed in the window
            window_seconds: Time window in seconds (default 5 minutes = 300 seconds)
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # Timestamps of granted requests, oldest first (appended in time order).
        self.requests: List[float] = []

    def is_allowed(self) -> bool:
        """Try to take a slot now.

        Returns:
            True (and records the request) if fewer than ``max_requests``
            requests were granted within the window, otherwise False.
        """
        now = time.time()

        # Drop timestamps that have left the window.
        self.requests = [
            req_time for req_time in self.requests
            if now - req_time < self.window_seconds
        ]

        if len(self.requests) < self.max_requests:
            self.requests.append(now)
            return True

        return False

    def _seconds_until_oldest_expires(self) -> Optional[float]:
        """Seconds until the oldest recorded request leaves the window.

        Returns None when nothing is recorded, i.e. waiting cannot free a slot.
        """
        if not self.requests:
            return None
        return self.window_seconds - (time.time() - self.requests[0])

    def wait_if_needed(self) -> None:
        """Block (with ``time.sleep``) until a slot is granted and recorded.

        Loops instead of trusting a single sleep: after waking, the slot may
        still be unavailable (e.g. taken by another thread), and the request
        must only proceed once ``is_allowed()`` has actually recorded it.
        """
        while not self.is_allowed():
            wait_time = self._seconds_until_oldest_expires()
            if wait_time is None:
                # Refused with nothing recorded (max_requests <= 0): no slot can
                # ever open, so let the caller through rather than hang forever.
                return
            if wait_time > 0:
                logger.warning(f"Rate limit reached, waiting {wait_time:.1f}s")
                time.sleep(wait_time)

    async def wait_if_needed_async(self) -> None:
        """Async counterpart of ``wait_if_needed`` that waits with ``asyncio.sleep``.

        Use this from coroutines so the event loop keeps serving other tasks
        while this one waits for a slot.
        """
        while not self.is_allowed():
            wait_time = self._seconds_until_oldest_expires()
            if wait_time is None:
                # Same max_requests <= 0 escape hatch as wait_if_needed.
                return
            if wait_time > 0:
                logger.warning(f"Rate limit reached, waiting {wait_time:.1f}s")
                await asyncio.sleep(wait_time)
|
|
|
|
|
|
class SimpleCache:
    """Simple in-memory cache for API responses with a single TTL.

    Entries are keyed by endpoint plus parameters. An entry older than
    ``ttl_seconds`` is treated as absent: ``get`` drops it when read, and
    ``purge_expired`` drops all such entries at once.
    """

    def __init__(self, ttl_seconds: int = 300):
        """
        Initialize cache

        Args:
            ttl_seconds: Time to live for cached items, in seconds
        """
        self.ttl_seconds = ttl_seconds
        # key -> (value, time.time() at which the value was stored)
        self.cache: Dict[str, tuple] = {}

    def _make_key(self, endpoint: str, params: Dict[str, Any]) -> str:
        """Create a cache key from the endpoint and parameters."""
        # Sorting makes the key independent of the dict's insertion order.
        params_str = str(sorted(params.items()))
        return f"{endpoint}:{params_str}"

    def _is_expired(self, timestamp: float, now: float) -> bool:
        """Return True if an entry stored at ``timestamp`` is past its TTL at ``now``."""
        return now - timestamp > self.ttl_seconds

    def get(self, endpoint: str, params: Dict[str, Any]) -> Optional[Any]:
        """Return the cached value, or None on a miss or an expired entry."""
        key = self._make_key(endpoint, params)

        entry = self.cache.get(key)
        if entry is None:
            return None

        value, timestamp = entry
        if self._is_expired(timestamp, time.time()):
            del self.cache[key]
            return None

        logger.debug(f"Cache hit for {key}")
        return value

    def set(self, endpoint: str, params: Dict[str, Any], value: Any) -> None:
        """Store ``value`` for the endpoint/params pair, stamped with the current time."""
        key = self._make_key(endpoint, params)
        self.cache[key] = (value, time.time())
        logger.debug(f"Cached {key}")

    def clear(self) -> None:
        """Remove every entry."""
        self.cache.clear()

    def purge_expired(self) -> int:
        """Remove every expired entry and return how many were removed.

        ``get`` only evicts the entry it reads, so entries that are never read
        again would otherwise stay in memory indefinitely; call this
        periodically (e.g. alongside session cleanup).
        """
        now = time.time()
        stale_keys = [
            key for key, (_, timestamp) in self.cache.items()
            if self._is_expired(timestamp, now)
        ]
        for key in stale_keys:
            del self.cache[key]
        return len(stale_keys)

    def stats(self) -> Dict[str, int]:
        """Return the total number of entries and how many of them are expired."""
        now = time.time()
        expired = sum(
            1 for _, (_, timestamp) in self.cache.items()
            if self._is_expired(timestamp, now)
        )
        return {
            "total_items": len(self.cache),
            "expired_items": expired,
        }
|
|
|
|
|
|
class Paginator:
    """Handle pagination for search results using 1-indexed pages."""

    def __init__(self, items: List[Dict[str, Any]], items_per_page: int = 5):
        """
        Initialize paginator

        Args:
            items: List of items to paginate
            items_per_page: Number of items per page; must be at least 1

        Raises:
            ValueError: If ``items_per_page`` is less than 1 (it would otherwise
                fail later with ZeroDivisionError or produce bogus page counts).
        """
        if items_per_page < 1:
            raise ValueError(
                f"items_per_page must be a positive integer, got {items_per_page!r}"
            )
        self.items = items
        self.items_per_page = items_per_page
        self.current_page = 1

    @property
    def total_pages(self) -> int:
        """Number of pages (ceiling division); 0 when there are no items."""
        return (len(self.items) + self.items_per_page - 1) // self.items_per_page

    @property
    def current_items(self) -> List[Dict[str, Any]]:
        """Items on the current page (the last page may be shorter)."""
        start = (self.current_page - 1) * self.items_per_page
        end = start + self.items_per_page
        return self.items[start:end]

    def next_page(self) -> bool:
        """Advance one page; return False if already on the last page."""
        if self.current_page < self.total_pages:
            self.current_page += 1
            return True
        return False

    def prev_page(self) -> bool:
        """Go back one page; return False if already on the first page."""
        if self.current_page > 1:
            self.current_page -= 1
            return True
        return False

    def goto_page(self, page: int) -> bool:
        """Jump to ``page`` (1-indexed); return False if it is out of range."""
        if 1 <= page <= self.total_pages:
            self.current_page = page
            return True
        return False

    def page_info(self) -> str:
        """Return the user-facing "page X/Y" label.

        An empty result set is shown as page 1 of 1 rather than "1/0".
        """
        shown_total = max(self.total_pages, 1)
        return f"Страница {self.current_page}/{shown_total}"
|
|
|
|
|
|
class UserSession:
    """Per-user key/value store that tracks when the user was last active."""

    def __init__(self, user_id: int):
        """Create an empty session for ``user_id`` stamped with the current time."""
        self.user_id = user_id
        self.data: Dict[str, Any] = {}
        self.created_at = time.time()
        self.last_activity = time.time()

    def set(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``, then mark the session as active."""
        self.data[key] = value
        self.update_activity()

    def get(self, key: str, default: Any = None) -> Any:
        """Mark the session as active, then return ``data[key]`` or ``default``."""
        self.update_activity()
        return self.data[key] if key in self.data else default

    def update_activity(self) -> None:
        """Refresh the last-activity timestamp to now."""
        self.last_activity = time.time()

    def is_idle(self, timeout_seconds: int = 1800) -> bool:
        """Return True when more than ``timeout_seconds`` passed since the last activity."""
        idle_for = time.time() - self.last_activity
        return idle_for > timeout_seconds

    def clear(self) -> None:
        """Drop all stored data (activity timestamps are left untouched)."""
        self.data.clear()
|
|
|
|
|
|
class SessionManager:
    """Create, look up and expire per-user UserSession objects keyed by user id."""

    def __init__(self, idle_timeout_seconds: int = 1800):
        """
        Initialize session manager

        Args:
            idle_timeout_seconds: Inactivity (in seconds) after which a session
                counts as idle and is eligible for cleanup
        """
        self.sessions: Dict[int, UserSession] = {}
        self.idle_timeout = idle_timeout_seconds

    def get_session(self, user_id: int) -> "UserSession":
        """Return the user's session, creating it if needed.

        An existing session has its activity timestamp refreshed.
        """
        if user_id not in self.sessions:
            self.sessions[user_id] = UserSession(user_id)
        else:
            self.sessions[user_id].update_activity()

        return self.sessions[user_id]

    def cleanup_idle_sessions(self) -> int:
        """Remove sessions idle longer than ``idle_timeout``; return how many were removed."""
        # Collect first, then delete: the dict must not change size while iterating.
        user_ids_to_remove = [
            user_id for user_id, session in self.sessions.items()
            if session.is_idle(self.idle_timeout)
        ]

        for user_id in user_ids_to_remove:
            del self.sessions[user_id]

        logger.info(f"Cleaned up {len(user_ids_to_remove)} idle sessions")
        return len(user_ids_to_remove)

    def stats(self) -> Dict[str, Any]:
        """Return session counts.

        ``active_sessions`` counts only sessions that are not idle, while
        ``total_users`` counts every session still held, including idle ones
        that have not been cleaned up yet.
        """
        active = sum(
            1 for session in self.sessions.values()
            if not session.is_idle(self.idle_timeout)
        )
        return {
            "active_sessions": active,
            "total_users": len(self.sessions),
        }
|
|
|
|
|
|
class RequestLogger:
    """Keep a bounded in-memory history of API requests for debugging."""

    def __init__(self, max_history: int = 100):
        """
        Initialize request logger

        Args:
            max_history: Maximum number of most recent requests to keep; older
                entries are discarded (0 or less keeps nothing)
        """
        self.requests: List[Dict[str, Any]] = []
        self.max_history = max_history

    def log_request(
        self,
        endpoint: str,
        method: str,
        params: Optional[Dict[str, Any]] = None,
        response_time: float = 0,
        status_code: int = 0,
        error: Optional[str] = None,
    ) -> None:
        """Append one request record, then trim history to ``max_history`` entries.

        Args:
            endpoint: API endpoint that was called
            method: HTTP method used
            params: Request parameters (a shallow copy is stored)
            response_time: How long the request took (units as supplied by the caller)
            status_code: Response status code (0 if unknown)
            error: Error description; any truthy value counts as a failed request
        """
        request_log = {
            "timestamp": time.time(),
            "endpoint": endpoint,
            "method": method,
            # Snapshot so later mutation of the caller's dict cannot rewrite history.
            "params": dict(params) if params is not None else None,
            "response_time": response_time,
            "status_code": status_code,
            "error": error,
        }

        self.requests.append(request_log)

        # Compute the overflow explicitly: slicing with [-max_history:] would keep
        # the whole list when max_history == 0 (because -0 == 0).
        excess = len(self.requests) - self.max_history
        if excess > 0:
            self.requests = self.requests[excess:]

    def get_stats(self) -> Dict[str, Any]:
        """Summarize the logged requests.

        Returns:
            ``{"requests_logged": 0}`` when nothing is logged; otherwise the count,
            total and average response time, number of errored requests, and the
            success rate as a percentage.
        """
        count = len(self.requests)
        if count == 0:
            return {"requests_logged": 0}

        total_time = sum(r["response_time"] for r in self.requests)
        errors = sum(1 for r in self.requests if r["error"])

        return {
            "requests_logged": count,
            "total_time": total_time,
            "average_time": total_time / count,
            "errors": errors,
            "success_rate": (count - errors) / count * 100,
        }
|
|
|
|
|
|
def rate_limit(limiter: RateLimiter):
    """Decorator that rate-limits an ``async`` callable through ``limiter``.

    Before each call the wrapper waits until ``limiter.is_allowed()`` grants
    (and records) a slot. The wait uses ``asyncio.sleep`` instead of the
    blocking ``limiter.wait_if_needed()`` so the event loop keeps serving other
    coroutines while this one is throttled.

    Args:
        limiter: Shared rate limiter; its ``is_allowed()``, ``requests`` (oldest
            first, ``time.time()`` stamps) and ``window_seconds`` are used.

    Returns:
        A decorator for coroutine functions.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # is_allowed() is synchronous and both checks and records the slot,
            # so no other coroutine can run between the check and the record.
            while not limiter.is_allowed():
                if not limiter.requests:
                    # Refused with nothing recorded (max_requests <= 0): waiting
                    # can never free a slot, so let the call through.
                    break
                wait_time = limiter.window_seconds - (time.time() - limiter.requests[0])
                if wait_time > 0:
                    logger.warning(f"Rate limit reached, waiting {wait_time:.1f}s")
                    await asyncio.sleep(wait_time)
            return await func(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
|
|
def with_cache(cache: "SimpleCache", ttl: int = 300):
    """Decorator that memoizes an ``async`` method's result in ``cache``.

    The key is the ``endpoint`` keyword plus the remaining keyword arguments
    and, when present, the positional arguments (stored under the parameter
    name ``"__args__"``). ``self`` is excluded, so all instances share entries.

    Args:
        cache: Object providing ``get(endpoint, params)`` (None on a miss) and
            ``set(endpoint, params, value)``.
        ttl: Not used by this decorator; expiry is governed by the cache's own
            TTL. Kept for API compatibility.

    Returns:
        A decorator for coroutine methods.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(self, *args, endpoint: str = "", **kwargs):
            params = dict(kwargs)
            if args:
                # Positional arguments must be part of the key; otherwise calls
                # that differ only positionally would share one cache entry.
                params["__args__"] = args

            cached = cache.get(endpoint, params)
            # Only None means "miss": falsy results such as [] or {} (e.g. an
            # empty search) are valid cache hits and must not be refetched.
            if cached is not None:
                return cached

            result = await func(self, *args, endpoint=endpoint, **kwargs)
            cache.set(endpoint, params, result)
            return result

        return wrapper

    return decorator
|