From 45c5ed638956dd84ba8c5b7a352ed2d0b731b4b7 Mon Sep 17 00:00:00 2001 From: grassblock Date: Wed, 6 Aug 2025 11:12:45 +0800 Subject: [PATCH] fix: query_hit returns all 0 and permanently store data --- .gitignore | 3 +- core/middleware/rikki.py | 65 ++++++++++++++++++++++++++++++++++++++-- core/rikki_hit.py | 23 ++++++++++++-- 3 files changed, 85 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 83100c6..42c3465 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ .venv/ secrets/ message_stats.json -config.yaml \ No newline at end of file +config.yaml +rikki_data.json \ No newline at end of file diff --git a/core/middleware/rikki.py b/core/middleware/rikki.py index 011947d..9f0fa72 100644 --- a/core/middleware/rikki.py +++ b/core/middleware/rikki.py @@ -1,5 +1,7 @@ import logging import time +import os +import json from collections import deque, defaultdict from dataclasses import dataclass @@ -24,10 +26,29 @@ class UserMetrics: neutral_count: int = 0 # 中性发言数(用于水群频率) + def to_dict(self) -> Dict: + """转换为字典""" + return { + 'cai_count': self.cai_count, + 'xm_count': self.xm_count, + 'nsfw_count': self.nsfw_count, + 'antisocial_count': self.antisocial_count, + 'total_count': self.total_count, + 'neutral_count': self.neutral_count + } + + @classmethod + def from_dict(cls, data: Dict) -> 'UserMetrics': + """从字典创建实例""" + return cls(**data) + + class RikkiMiddleware(BaseMiddleware): - def __init__(self, target_user_id: str = "5545347637"): - # 存储每个用户的触发几率,初始值在40-50%之间 + def __init__(self, target_user_id: str = "5545347637", data_file: str = "rikki_data.json"): + # 存储每个用户的触发几率 self.user_probabilities: Dict[str, float] = {} + self.data_file = data_file + # 触发关键词 self.xm_keywords = ["啃", "羡慕", "xm", "xmsl", "羡慕死了", "我菜"] self.cai_keywords = ["菜", "菜了", "菜死了", "我菜", "废物"] @@ -58,6 +79,44 @@ class RikkiMiddleware(BaseMiddleware): self.has_sent_warning: Dict[str, bool] = defaultdict(bool) + # 加载持久化数据 + self.load_data() + + def load_data(self) -> None: + """从JSON文件加载数据""" + if os.path.exists(self.data_file): + try: + with open(self.data_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + # 加载用户指标 + if 'user_metrics' in data: + for user_id, metrics_data in data['user_metrics'].items(): + self.user_metrics[user_id] = UserMetrics.from_dict(metrics_data) + + # 加载警告状态 + if 'has_sent_warning' in data: + self.has_sent_warning.update(data['has_sent_warning']) + + except (json.JSONDecodeError, KeyError) as e: + logging.warning(f"加载数据文件失败: {e}") + + def save_data(self,user_id:str = '5545347637') -> None: + """保存数据到JSON文件""" + try: + data = { + 'hit_prob': self.calculate_qianda_score(user_id), + 'user_metrics': {user_id: metrics.to_dict() + for user_id, metrics in self.user_metrics.items()}, + 'has_sent_warning': dict(self.has_sent_warning) + } + + with open(self.data_file, 'w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + except Exception as e: + logging.error(f"保存数据文件失败: {e}") + def record_message(self, user_id: str) -> None: """记录用户发送消息的时间""" current_time = time.time() @@ -189,6 +248,8 @@ class RikkiMiddleware(BaseMiddleware): score = self.calculate_qianda_score(user_id) logging.debug("当前欠打的几率是{}".format(score)) + self.save_data() + return await handler(event, data) def get_user_status(self, user_id: str) -> str: diff --git a/core/rikki_hit.py b/core/rikki_hit.py index 2fa5d45..c00ac66 100644 --- a/core/rikki_hit.py +++ b/core/rikki_hit.py @@ -1,6 +1,23 @@ +import json + from aiogram.types import Message -from core.middleware.rikki import RikkiMiddleware async def handle_query_hit_command(message: Message) -> None: - hit_status = RikkiMiddleware().get_user_status("5545347637") - await message.reply(hit_status) \ No newline at end of file + hit_status = '' + with open('rikki_data.json', 'r', encoding='utf-8') as f: + hit_status = json.load(f) + _id = str(message.from_user.id) + user_data = hit_status['user_metrics'].get('5545347637', { + "cai_count": 0, + "xm_count": 0, + "nsfw_count": 0, + "antisocial_count": 0, + "total_count": 0, + "neutral_count": 0 + }) + + hit_prob = hit_status.get('hit_prob', 0.0) + + formatted_message = f"欠打度: {hit_prob:.2f}%\n卖菜: {user_data['cai_count']}, 羡慕: {user_data['xm_count']}, NSFW: {user_data['nsfw_count']}, 反社会: {user_data['antisocial_count']}, 中性: {user_data['neutral_count']}\n总发言: {user_data['total_count']}" + + await message.reply(formatted_message)