fix: query_hit returns all 0 and permanently store data

This commit is contained in:
草师傅 2025-08-06 11:12:45 +08:00
parent 24deaf811d
commit 45c5ed6389
3 changed files with 85 additions and 6 deletions

View file

@ -1,5 +1,7 @@
import logging
import time
import os
import json
from collections import deque, defaultdict
from dataclasses import dataclass
@ -24,10 +26,29 @@ class UserMetrics:
neutral_count: int = 0 # 中性发言数(用于水群频率)
def to_dict(self) -> Dict:
"""转换为字典"""
return {
'cai_count': self.cai_count,
'xm_count': self.xm_count,
'nsfw_count': self.nsfw_count,
'antisocial_count': self.antisocial_count,
'total_count': self.total_count,
'neutral_count': self.neutral_count
}
@classmethod
def from_dict(cls, data: Dict) -> 'UserMetrics':
"""从字典创建实例"""
return cls(**data)
class RikkiMiddleware(BaseMiddleware):
def __init__(self, target_user_id: str = "5545347637"):
# 存储每个用户的触发几率初始值在40-50%之间
def __init__(self, target_user_id: str = "5545347637", data_file: str = "rikki_data.json"):
# 存储每个用户的触发几率
self.user_probabilities: Dict[str, float] = {}
self.data_file = data_file
# 触发关键词
self.xm_keywords = ["", "羡慕", "xm", "xmsl", "羡慕死了", "我菜"]
self.cai_keywords = ["", "菜了", "菜死了", "我菜", "废物"]
@ -58,6 +79,44 @@ class RikkiMiddleware(BaseMiddleware):
self.has_sent_warning: Dict[str, bool] = defaultdict(bool)
# 加载持久化数据
self.load_data()
def load_data(self) -> None:
"""从JSON文件加载数据"""
if os.path.exists(self.data_file):
try:
with open(self.data_file, 'r', encoding='utf-8') as f:
data = json.load(f)
# 加载用户指标
if 'user_metrics' in data:
for user_id, metrics_data in data['user_metrics'].items():
self.user_metrics[user_id] = UserMetrics.from_dict(metrics_data)
# 加载警告状态
if 'has_sent_warning' in data:
self.has_sent_warning.update(data['has_sent_warning'])
except (json.JSONDecodeError, KeyError) as e:
logging.warning(f"加载数据文件失败: {e}")
def save_data(self,user_id:str = '5545347637') -> None:
"""保存数据到JSON文件"""
try:
data = {
'hit_prob': self.calculate_qianda_score(user_id),
'user_metrics': {user_id: metrics.to_dict()
for user_id, metrics in self.user_metrics.items()},
'has_sent_warning': dict(self.has_sent_warning)
}
with open(self.data_file, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
except Exception as e:
logging.error(f"保存数据文件失败: {e}")
def record_message(self, user_id: str) -> None:
"""记录用户发送消息的时间"""
current_time = time.time()
@ -189,6 +248,8 @@ class RikkiMiddleware(BaseMiddleware):
score = self.calculate_qianda_score(user_id)
logging.debug("当前欠打的几率是{}".format(score))
self.save_data()
return await handler(event, data)
def get_user_status(self, user_id: str) -> str: