Files
ayako/advanced_features.py
2026-05-01 15:13:02 +03:00

326 lines
9.8 KiB
Python

"""
Advanced features for VNDB Telegram Bot
Includes pagination, caching, and rate limiting
"""
import asyncio
import time
from typing import Dict, List, Any, Optional, Callable
from collections import defaultdict
from functools import wraps
import logging
logger = logging.getLogger(__name__)
class RateLimiter:
    """Sliding-window rate limiter for API requests.

    The limiter remembers the timestamps of the requests it has admitted and
    admits a new one only while fewer than ``max_requests`` of them are
    younger than ``window_seconds``.
    """

    def __init__(self, max_requests: int = 200, window_seconds: int = 300):
        """
        Initialize rate limiter

        Args:
            max_requests: Maximum requests allowed in window
            window_seconds: Time window in seconds (default 5 minutes = 300 seconds)

        Raises:
            ValueError: If ``max_requests`` is smaller than 1. Such a limiter
                could never admit a request, and ``wait_if_needed`` would
                have nothing to wait for.
        """
        if max_requests < 1:
            raise ValueError("max_requests must be at least 1")
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # Timestamps (time.time()) of admitted requests, oldest first.
        self.requests: List[float] = []

    def is_allowed(self) -> bool:
        """Check if a request is allowed, recording it when it is.

        Returns:
            True if the request fits into the current window (its timestamp
            is then stored), False otherwise.
        """
        now = time.time()
        # Remove old requests outside the window
        self.requests = [
            req_time for req_time in self.requests
            if now - req_time < self.window_seconds
        ]
        # Check if we can make another request
        if len(self.requests) < self.max_requests:
            self.requests.append(now)
            return True
        return False

    def wait_if_needed(self) -> None:
        """Block the calling thread until a request slot is free, then take it.

        The call sleeps with ``time.sleep`` until the oldest recorded request
        leaves the window and then re-checks, repeating until ``is_allowed``
        actually admits (and records) the request.
        """
        while not self.is_allowed():
            if not self.requests:
                # Nothing recorded means sleeping cannot free a slot (only
                # possible if max_requests was lowered below 1 after
                # construction); return instead of looping forever.
                return
            wait_time = self.window_seconds - (time.time() - self.requests[0])
            if wait_time > 0:
                logger.warning("Rate limit reached, waiting %.1fs", wait_time)
                time.sleep(wait_time)
            # Loop again: the slot only counts once is_allowed() has
            # recorded this request.
class SimpleCache:
    """Simple in-memory cache for API responses.

    Entries are stored as ``(value, timestamp)`` tuples keyed by a string
    built from the endpoint and its parameters, and are considered expired
    once they are older than ``ttl_seconds``.
    """

    def __init__(self, ttl_seconds: int = 300):
        """
        Initialize cache

        Args:
            ttl_seconds: Time to live for cached items
        """
        self.ttl_seconds = ttl_seconds
        self.cache: Dict[str, tuple] = {}
        # Time of the last full sweep; used to throttle sweeps done by set().
        self._last_purge = time.time()

    def _make_key(self, endpoint: str, params: Optional[Dict[str, Any]]) -> str:
        """Create cache key from endpoint and parameters.

        Parameters are sorted by name so that keyword order does not matter.
        ``None`` is treated the same as an empty parameter dict.
        """
        params_str = str(sorted((params or {}).items()))
        return f"{endpoint}:{params_str}"

    def get(self, endpoint: str, params: Optional[Dict[str, Any]]) -> Optional[Any]:
        """Get item from cache.

        Returns:
            The cached value, or None if the key is absent or its entry has
            expired (an expired entry is removed on the way).
        """
        key = self._make_key(endpoint, params)
        if key not in self.cache:
            return None
        value, timestamp = self.cache[key]
        # Check if expired
        if time.time() - timestamp > self.ttl_seconds:
            del self.cache[key]
            return None
        logger.debug(f"Cache hit for {key}")
        return value

    def set(self, endpoint: str, params: Optional[Dict[str, Any]], value: Any) -> None:
        """Set item in cache.

        Also sweeps expired entries, at most once per TTL period, so keys
        that are never requested again do not pile up indefinitely.
        """
        now = time.time()
        if now - self._last_purge > self.ttl_seconds:
            self.purge_expired()
        key = self._make_key(endpoint, params)
        self.cache[key] = (value, now)
        logger.debug(f"Cached {key}")

    def purge_expired(self) -> int:
        """Remove every expired entry.

        Returns:
            Number of entries that were removed.
        """
        now = time.time()
        stale_keys = [
            key for key, (_, timestamp) in self.cache.items()
            if now - timestamp > self.ttl_seconds
        ]
        for key in stale_keys:
            del self.cache[key]
        self._last_purge = now
        return len(stale_keys)

    def clear(self) -> None:
        """Clear all cache"""
        self.cache.clear()

    def stats(self) -> Dict[str, int]:
        """Get cache statistics.

        Returns:
            Dict with ``total_items`` (all stored entries, expired or not)
            and ``expired_items`` (stored entries past their TTL).
        """
        now = time.time()
        expired = sum(
            1 for _, timestamp in self.cache.values()
            if now - timestamp > self.ttl_seconds
        )
        return {
            "total_items": len(self.cache),
            "expired_items": expired,
        }
class Paginator:
    """Handle pagination for search results.

    Pages are numbered from 1; ``current_page`` starts at 1 and is only
    changed through ``next_page``, ``prev_page`` and ``goto_page``.
    """

    def __init__(self, items: List[Dict[str, Any]], items_per_page: int = 5):
        """
        Initialize paginator

        Args:
            items: List of items to paginate
            items_per_page: Number of items per page

        Raises:
            ValueError: If ``items_per_page`` is smaller than 1 (it is used
                as a divisor and as a slice length).
        """
        if items_per_page < 1:
            raise ValueError("items_per_page must be at least 1")
        self.items = items
        self.items_per_page = items_per_page
        self.current_page = 1

    @property
    def total_pages(self) -> int:
        """Get total number of pages (0 when there are no items)."""
        # Ceiling division without floats.
        return (len(self.items) + self.items_per_page - 1) // self.items_per_page

    @property
    def current_items(self) -> List[Dict[str, Any]]:
        """Get items for current page"""
        start = (self.current_page - 1) * self.items_per_page
        end = start + self.items_per_page
        return self.items[start:end]

    def next_page(self) -> bool:
        """Go to next page.

        Returns:
            True if the page changed, False if already on the last page.
        """
        if self.current_page < self.total_pages:
            self.current_page += 1
            return True
        return False

    def prev_page(self) -> bool:
        """Go to previous page.

        Returns:
            True if the page changed, False if already on the first page.
        """
        if self.current_page > 1:
            self.current_page -= 1
            return True
        return False

    def goto_page(self, page: int) -> bool:
        """Go to specific page.

        Returns:
            True if ``page`` is within ``1..total_pages`` and was selected,
            False otherwise (the current page is left unchanged).
        """
        if 1 <= page <= self.total_pages:
            self.current_page = page
            return True
        return False

    def page_info(self) -> str:
        """Get page information string, e.g. "Страница 2/7"."""
        # An empty result set has zero pages; show "1/1" rather than the
        # confusing "1/0".
        shown_total = max(self.total_pages, 1)
        return f"Страница {self.current_page}/{shown_total}"
class UserSession:
    """Per-user key/value store that tracks when it was last touched."""

    def __init__(self, user_id: int):
        """Create an empty session for ``user_id`` and stamp its creation time."""
        self.user_id = user_id
        self.data: Dict[str, Any] = {}
        self.created_at = time.time()
        self.update_activity()

    def set(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` and mark the session as active."""
        self.data[key] = value
        self.update_activity()

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value stored under ``key`` (or ``default``).

        Reading also counts as activity.
        """
        self.update_activity()
        if key in self.data:
            return self.data[key]
        return default

    def update_activity(self) -> None:
        """Record the current time as the moment of last activity."""
        self.last_activity = time.time()

    def is_idle(self, timeout_seconds: int = 1800) -> bool:
        """Tell whether more than ``timeout_seconds`` passed since last activity."""
        idle_for = time.time() - self.last_activity
        return idle_for > timeout_seconds

    def clear(self) -> None:
        """Drop all stored values (the activity timestamp is left as is)."""
        self.data.clear()
class SessionManager:
    """Manage user sessions, keyed by user id."""

    def __init__(self, idle_timeout_seconds: int = 1800):
        """
        Initialize session manager

        Args:
            idle_timeout_seconds: Timeout for idle sessions
        """
        self.sessions: Dict[int, UserSession] = {}
        self.idle_timeout = idle_timeout_seconds

    def get_session(self, user_id: int) -> UserSession:
        """Get or create user session.

        An existing session has its activity timestamp refreshed; a missing
        one is created and stored.
        """
        session = self.sessions.get(user_id)
        if session is None:
            session = UserSession(user_id)
            self.sessions[user_id] = session
        else:
            session.update_activity()
        return session

    def cleanup_idle_sessions(self) -> int:
        """Remove idle sessions.

        Returns:
            Number of sessions that were removed.
        """
        # Collect first: the dict must not change size while it is iterated.
        user_ids_to_remove = [
            user_id for user_id, session in self.sessions.items()
            if session.is_idle(self.idle_timeout)
        ]
        for user_id in user_ids_to_remove:
            del self.sessions[user_id]
        logger.info(f"Cleaned up {len(user_ids_to_remove)} idle sessions")
        return len(user_ids_to_remove)

    def stats(self) -> Dict[str, Any]:
        """Get session statistics.

        Returns:
            Dict with ``active_sessions`` (stored sessions that are not idle
            with respect to the manager's timeout) and ``total_users`` (all
            stored sessions, including idle ones awaiting cleanup).
        """
        active = sum(
            1 for session in self.sessions.values()
            if not session.is_idle(self.idle_timeout)
        )
        return {
            "active_sessions": active,
            "total_users": len(self.sessions),
        }
class RequestLogger:
    """Log API requests for debugging.

    Keeps the most recent ``max_history`` request records in memory.
    """

    def __init__(self, max_history: int = 100):
        """Initialize request logger.

        Args:
            max_history: Maximum number of request records to keep.

        Raises:
            ValueError: If ``max_history`` is negative.
        """
        if max_history < 0:
            raise ValueError("max_history must not be negative")
        self.requests: List[Dict[str, Any]] = []
        self.max_history = max_history

    def log_request(
        self,
        endpoint: str,
        method: str,
        params: Optional[Dict[str, Any]] = None,
        response_time: float = 0,
        status_code: int = 0,
        error: Optional[str] = None,
    ) -> None:
        """Log an API request.

        Args:
            endpoint: Endpoint that was called.
            method: Method used for the call.
            params: Parameters sent with the request, if any.
            response_time: Time the request took.
            status_code: Status code of the response.
            error: Error description; a truthy value marks the request as failed.
        """
        request_log = {
            "timestamp": time.time(),
            "endpoint": endpoint,
            "method": method,
            "params": params,
            "response_time": response_time,
            "status_code": status_code,
            "error": error,
        }
        self.requests.append(request_log)
        # Keep only recent requests. Deleting the overflow explicitly (rather
        # than slicing with [-max_history:]) also works for max_history == 0,
        # where a "-0" slice would keep everything.
        overflow = len(self.requests) - self.max_history
        if overflow > 0:
            del self.requests[:overflow]

    def get_stats(self) -> Dict[str, Any]:
        """Get request statistics.

        Returns:
            Dict with ``requests_logged``, ``total_time``, ``average_time``,
            ``errors`` and ``success_rate`` (percentage). The same keys are
            present when nothing has been logged, with zero totals and a
            success rate of 100.0 because no request has failed.
        """
        count = len(self.requests)
        if count == 0:
            return {
                "requests_logged": 0,
                "total_time": 0.0,
                "average_time": 0.0,
                "errors": 0,
                "success_rate": 100.0,
            }
        total_time = sum(r["response_time"] for r in self.requests)
        errors = sum(1 for r in self.requests if r["error"])
        return {
            "requests_logged": count,
            "total_time": total_time,
            "average_time": total_time / count,
            "errors": errors,
            "success_rate": (count - errors) / count * 100,
        }
def rate_limit(limiter: RateLimiter):
    """Decorator for rate limiting coroutine functions.

    Before each call the wrapper asks ``limiter`` for a slot. While none is
    free it waits with ``asyncio.sleep`` until the oldest recorded request
    leaves the window, so only the calling coroutine is suspended and the
    event loop keeps running.

    Args:
        limiter: Limiter consulted (and updated) before every call.

    Returns:
        A decorator that wraps a coroutine function.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            while not limiter.is_allowed():
                recorded = limiter.requests
                if not recorded:
                    # No recorded request to wait on, so sleeping cannot free
                    # a slot; proceed instead of spinning forever.
                    break
                delay = limiter.window_seconds - (time.time() - recorded[0])
                if delay > 0:
                    logger.warning(f"Rate limit reached, waiting {delay:.1f}s")
                # A zero-length sleep still yields to the event loop before
                # the limiter is asked again.
                await asyncio.sleep(max(delay, 0))
            return await func(*args, **kwargs)
        return wrapper
    return decorator
def with_cache(cache: SimpleCache, ttl: int = 300):
    """Decorator for caching results of coroutine methods.

    The wrapped method must accept ``endpoint`` as a keyword argument. The
    cache key is built from the endpoint, the remaining keyword arguments
    and any positional arguments (other than ``self``).

    Args:
        cache: Cache used to look up and store results.
        ttl: Accepted for backward compatibility but not applied here;
            expiry is decided by ``cache`` itself.

    Returns:
        A decorator that wraps a coroutine method.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(self, *args, endpoint: str = "", **kwargs):
            # Positional arguments are part of the request too; without them
            # two different calls could share one cache entry.
            key_params = dict(kwargs)
            if args:
                key_params["__args__"] = args
            # Try to get from cache. Compare against None so that falsy
            # results (empty list/dict, 0) still count as hits.
            cached = cache.get(endpoint, key_params)
            if cached is not None:
                return cached
            # Call function
            result = await func(self, *args, endpoint=endpoint, **kwargs)
            # Cache result. None is skipped: it is indistinguishable from a
            # miss on lookup, so storing it would only take up space.
            if result is not None:
                cache.set(endpoint, key_params, result)
            return result
        return wrapper
    return decorator