|
|
|
|
|
|
|
|
""" |
|
|
Qwen3-Reranker 推理测试代码 |
|
|
使用 RKLLM API 进行文本重排序推理 |
|
|
""" |
|
|
|
|
|
import faulthandler

# Print a Python traceback on fatal signals (e.g. a crash inside the native
# RKLLM library) instead of the process dying silently.
faulthandler.enable()

import os

# Set before rkllm_binding is imported below — presumably so the native runtime
# picks it up when the library loads; the meaning of level "1" is defined by
# the RKLLM runtime (TODO confirm against the RKLLM documentation).
os.environ["RKLLM_LOG_LEVEL"] = "1"

import numpy as np

import time

import re

from typing import List, Dict, Any, Tuple

# Provides RKLLMRuntime, RKLLMInput, RKLLMInferParam, LLMCallState and the
# other RKLLM enums/structs used below.
from rkllm_binding import *
|
|
|
|
|
|
|
|
class Qwen3RerankerTester:
    """Score (query, document) pairs with a Qwen3-Reranker model run through RKLLM.

    The relevance score is the softmax probability of the "yes" token against
    the "no" token at the last prompt position, where the prompt follows the
    official Qwen3-Reranker chat format (system prompt, user turn with
    Instruct/Query/Document, and an assistant turn with an empty think block).
    """

    # Instruction used when the caller does not supply one.
    DEFAULT_INSTRUCTION = 'Given a web search query, retrieve relevant passages that answer the query'

    # Closing of the user turn plus the opening of the assistant turn. The empty
    # <think> block skips reasoning so the next predicted token is expected to
    # be the yes/no answer.
    PROMPT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

    def __init__(self, model_path, library_path="./librkllmrt.so", *, max_context_len=1024):
        """Create a tester; nothing is loaded until :meth:`init_model` is called.

        Args:
            model_path: Path to the model file (``.rkllm`` format).
            library_path: Path to the RKLLM runtime shared library.
            max_context_len: Context length handed to the runtime; it must hold
                the whole prompt (system prompt + query + document).
        """
        self.model_path = model_path
        self.library_path = library_path
        self.max_context_len = max_context_len
        self.runtime = None
        # Written by callback_function during each run; read by get_reranker_score.
        self.current_result = None

        # System prompt of the Qwen3-Reranker prompt template (used by build_chat_prompt).
        self.system_prompt = "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\"."

        # Candidate token ids for "yes" / "no"; expected to be the Qwen3
        # tokenizer ids of those words — TODO confirm if the tokenizer changes.
        self.yes_token_candidates = [9693]
        self.no_token_candidates = [2152]

    def callback_function(self, result_ptr, userdata_ptr, state_enum):
        """Runtime callback: store the logits of the last prompt token.

        Args:
            result_ptr: Pointer to the runtime result struct (``.contents.logits``).
            userdata_ptr: User-data pointer (unused).
            state_enum: Raw call-state value, converted to ``LLMCallState``.
        """
        state = LLMCallState(state_enum)

        if state == LLMCallState.RKLLM_RUN_NORMAL:
            logits_info = result_ptr.contents.logits
            vocab_size = logits_info.vocab_size
            num_tokens = logits_info.num_tokens

            # A NULL logits pointer is falsy.
            if not (logits_info.logits and vocab_size > 0):
                print("警告: 未能获取到 logits")
                return

            print(f"获取到 logits:vocab_size={vocab_size}, num_tokens={num_tokens}")
            if num_tokens <= 0:
                return

            # The buffer is num_tokens consecutive rows of vocab_size logits;
            # only the last row predicts the token that follows the prompt.
            # One slice copies the row instead of vocab_size separate indexings.
            start_idx = (num_tokens - 1) * vocab_size
            last_token_logits = list(logits_info.logits[start_idx:start_idx + vocab_size])

            self.current_result = {
                'logits': last_token_logits,
                'vocab_size': vocab_size,
                'num_tokens': num_tokens
            }

            print(f"最后一个 token 的 logits 范围: [{min(last_token_logits):.4f}, {max(last_token_logits):.4f}]")

        elif state == LLMCallState.RKLLM_RUN_ERROR:
            print("推理过程发生错误")

    def find_best_yes_no_tokens(self, logits):
        """Pick the highest-scoring "yes" and "no" candidate tokens.

        Args:
            logits: Sequence of logits, one per vocabulary entry.

        Returns:
            ``(yes_token_id, no_token_id, yes_logit, no_logit)``.

        Raises:
            LookupError: If no "yes" or no "no" candidate lies inside the
                vocabulary. No substitute tokens are invented, because any
                score computed from them would be meaningless.
        """
        vocab_size = len(logits)

        def best_of(candidates):
            # Highest logit among the in-range candidates, or (None, -inf).
            best_id, best_logit = None, float('-inf')
            for token_id in candidates:
                if 0 <= token_id < vocab_size and logits[token_id] > best_logit:
                    best_id, best_logit = token_id, logits[token_id]
            return best_id, best_logit

        yes_id, yes_logit = best_of(self.yes_token_candidates)
        no_id, no_logit = best_of(self.no_token_candidates)

        if yes_id is None or no_id is None:
            raise LookupError(f"词表中找不到 yes/no 候选 token(vocab_size={vocab_size})")

        return yes_id, no_id, yes_logit, no_logit

    def calculate_reranker_score(self, logits):
        """Compute the relevance score as P("yes") under a softmax over {yes, no}.

        Falls back to :meth:`fallback_score_calculation` when the yes/no tokens
        cannot be located or any other error occurs.

        Args:
            logits: Sequence of logits, one per vocabulary entry.

        Returns:
            Relevance score in [0, 1]; higher means more relevant.
        """
        try:
            yes_id, no_id, yes_logit, no_logit = self.find_best_yes_no_tokens(logits)

            print(f"Yes token ID: {yes_id}, logit: {yes_logit:.4f}")
            print(f"No token ID: {no_id}, logit: {no_logit:.4f}")

            # Subtract the max before exponentiating to avoid overflow.
            max_logit = max(yes_logit, no_logit)
            yes_exp = np.exp(yes_logit - max_logit)
            no_exp = np.exp(no_logit - max_logit)

            sum_exp = yes_exp + no_exp
            yes_prob = yes_exp / sum_exp
            no_prob = no_exp / sum_exp

            print(f"Yes 概率: {yes_prob:.4f}, No 概率: {no_prob:.4f}")

            return float(yes_prob)

        except Exception as e:
            print(f"计算 reranker 分数时发生错误: {e}")
            return self.fallback_score_calculation(logits)

    def fallback_score_calculation(self, logits):
        """Heuristic score used when the yes/no tokens cannot be found.

        Blends 0.7 * (1 - normalised entropy of the full softmax) with
        0.3 * (z-score of the max logit / 10, clipped to [0, 1]).

        Args:
            logits: Sequence of logits, one per vocabulary entry.

        Returns:
            Score in [0, 1]; 0.0 when fewer than two logits are given.
        """
        print("使用备用分数计算方法")

        logits_array = np.asarray(logits, dtype=np.float64)
        if logits_array.size < 2:
            # Entropy is normalised by log(size), which is 0 for one entry and
            # undefined for none, so no meaningful score exists.
            print("警告: logits 数量不足,无法计算备用分数")
            return 0.0

        softmax_probs = np.exp(logits_array - np.max(logits_array))
        softmax_probs = softmax_probs / np.sum(softmax_probs)

        # Low entropy (a peaked distribution) is read as high confidence.
        entropy = -np.sum(softmax_probs * np.log(softmax_probs + 1e-10))
        max_entropy = np.log(logits_array.size)
        normalized_entropy = entropy / max_entropy
        confidence_score = 1.0 - normalized_entropy

        max_logit_score = (np.max(logits_array) - np.mean(logits_array)) / (np.std(logits_array) + 1e-8)
        max_logit_score = max(0, min(1, max_logit_score / 10))

        final_score = 0.7 * confidence_score + 0.3 * max_logit_score
        final_score = float(max(0.0, min(1.0, final_score)))

        print(f"备用计算 - 熵分数: {confidence_score:.4f}, 最大logit分数: {max_logit_score:.4f}, 最终分数: {final_score:.4f}")

        return final_score

    def init_model(self):
        """Load the RKLLM runtime and initialise the model for logits inference.

        Raises:
            Exception: Any error from the runtime; it is printed and re-raised.
        """
        try:
            print(f"初始化 RKLLM 运行时,库路径: {self.library_path}")
            self.runtime = RKLLMRuntime(self.library_path)

            print("创建默认参数...")
            params = self.runtime.create_default_param()

            params.model_path = self.model_path.encode('utf-8')
            params.max_context_len = self.max_context_len
            # Scoring only needs the logits after the prompt; decoding is
            # limited to one token and made fully greedy.
            params.max_new_tokens = 1
            params.temperature = 0.0
            params.top_k = 1
            params.top_p = 1.0

            params.extend_param.base_domain_id = 1
            params.extend_param.embed_flash = 0
            params.extend_param.enabled_cpus_num = 4
            # NOTE(review): 0x0F selects CPUs 0-3; on big.LITTLE SoCs confirm
            # these are the intended cores.
            params.extend_param.enabled_cpus_mask = 0x0F

            print(f"初始化模型: {self.model_path}")
            self.runtime.init(params, self.callback_function)

            # Empty template, presumably so the runtime adds no wrapping of its
            # own: build_chat_prompt already produces the complete chat prompt.
            self.runtime.set_chat_template("", "", "")

            print("模型初始化成功!")

        except Exception as e:
            print(f"模型初始化失败: {e}")
            raise

    def format_rerank_input(self, instruction, query, document):
        """Build the user-turn content in the Qwen3-Reranker format.

        Args:
            instruction: Task instruction; ``None`` selects DEFAULT_INSTRUCTION.
            query: Query text.
            document: Document text.

        Returns:
            ``"<Instruct>: …\\n<Query>: …\\n<Document>: …"``.
        """
        if instruction is None:
            instruction = self.DEFAULT_INSTRUCTION

        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}"

    def build_chat_prompt(self, content):
        """Wrap user-turn *content* in the full Qwen3-Reranker chat template.

        Args:
            content: Output of :meth:`format_rerank_input`.

        Returns:
            System turn (with ``self.system_prompt``) + user turn + the opened
            assistant turn, ready to be sent to the runtime verbatim.
        """
        prefix = f"<|im_start|>system\n{self.system_prompt}<|im_end|>\n<|im_start|>user\n"
        return prefix + content + self.PROMPT_SUFFIX

    def get_reranker_score(self, instruction, query, document):
        """Run one forward pass and return the relevance score of *document*.

        Args:
            instruction: Task instruction (``None`` for the default).
            query: Query text.
            document: Document text.

        Returns:
            Relevance score in [0, 1]; 0.0 if no logits were produced or an
            error occurred (the error is printed, not raised).
        """
        try:
            input_text = self.format_rerank_input(instruction, query, document)
            print(f"\n重排序输入: {input_text[:200]}{'...' if len(input_text) > 200 else ''}")

            rk_input = RKLLMInput()
            rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
            # Local name keeps the encoded bytes alive while run() executes.
            c_prompt = self.build_chat_prompt(input_text).encode('utf-8')
            rk_input._union_data.prompt_input = c_prompt

            infer_params = RKLLMInferParam()
            infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LOGITS
            infer_params.keep_history = 0

            # Score every pair independently: forget the previous result and
            # clear the KV cache before running.
            self.current_result = None
            self.runtime.clear_kv_cache(False)

            # run() is expected to invoke callback_function before returning.
            start_time = time.perf_counter()
            self.runtime.run(rk_input, infer_params)
            elapsed = time.perf_counter() - start_time

            print(f"\n推理耗时: {elapsed:.3f}秒")

            if not (self.current_result and 'logits' in self.current_result):
                print("警告: 未能获取到有效的 logits,返回默认分数")
                return 0.0

            score = self.calculate_reranker_score(self.current_result['logits'])
            print(f"计算得分: {score:.4f}")
            return score

        except Exception as e:
            print(f"重排序评分时发生错误: {e}")
            import traceback
            traceback.print_exc()
            return 0.0

    def rerank_documents(self, query, documents, instruction=None):
        """Score every document against *query* and sort by score.

        Args:
            query: Query text.
            documents: List of document strings.
            instruction: Optional task instruction.

        Returns:
            List of ``(document, score)`` tuples, highest score first (ties
            keep input order).
        """
        total = len(documents)
        print(f"\n对 {total} 个文档进行重排序")
        print(f"查询: {query}")
        if instruction:
            print(f"指令: {instruction}")

        scored_docs = []
        for index, doc in enumerate(documents, start=1):
            print(f"\n--- 处理文档 {index}/{total} ---")
            print(f"文档: {doc[:100]}{'...' if len(doc) > 100 else ''}")

            score = self.get_reranker_score(instruction, query, doc)
            scored_docs.append((doc, score))
            print(f"得分: {score:.4f}")

        scored_docs.sort(key=lambda pair: pair[1], reverse=True)
        return scored_docs

    @staticmethod
    def _print_section(title):
        """Print *title* framed by two 60-character '=' rules."""
        print("\n" + "=" * 60)
        print(title)
        print("=" * 60)

    @staticmethod
    def _print_ranking(title, ranked_docs):
        """Print a ranked ``(document, score)`` list under *title*."""
        print(title)
        print("-" * 80)
        for rank, (doc, score) in enumerate(ranked_docs, start=1):
            print(f"排名 {rank}: 分数 {score:.4f}")
            print(f"文档: {doc}")
            print()

    def test_basic_reranking(self):
        """Rerank English documents for a simple factual query."""
        self._print_section("测试基础重排序功能")

        query = "What is the capital of China?"
        documents = [
            "Beijing is the capital city of China, located in northern China.",
            "The Great Wall of China is an ancient fortification built to protect Chinese states.",
            "Python is a high-level programming language used for software development.",
            "China's capital Beijing is home to over 21 million people.",
            "Machine learning is a subset of artificial intelligence that uses algorithms."
        ]

        ranked_docs = self.rerank_documents(query, documents, self.DEFAULT_INSTRUCTION)
        self._print_ranking(f"\n重排序结果(查询: {query}):", ranked_docs)
        return ranked_docs

    def test_multilingual_reranking(self):
        """Rerank mixed Chinese/English documents for a Chinese query."""
        self._print_section("测试多语言重排序功能")

        query = "中国的首都是什么?"
        documents = [
            "北京是中华人民共和国的首都,位于中国北部。",
            "上海是中国的经济中心,人口超过2400万。",
            "Python 是一种高级编程语言。",
            "The capital of China is Beijing.",
            "长城是中国古代的军事防御工程。"
        ]

        ranked_docs = self.rerank_documents(query, documents, self.DEFAULT_INSTRUCTION)
        self._print_ranking(f"\n多语言重排序结果(查询: {query}):", ranked_docs)
        return ranked_docs

    def test_domain_specific_reranking(self):
        """Rerank technical documents with a task-specific instruction."""
        self._print_section("测试领域特定重排序(技术文档)")

        query = "How to implement a neural network in Python?"
        documents = [
            "PyTorch is a deep learning framework that provides tensor computations with GPU acceleration.",
            "TensorFlow is an open-source machine learning library developed by Google.",
            "Neural networks are computing systems inspired by biological neural networks.",
            "Python is a programming language with simple syntax and powerful libraries.",
            "To implement a neural network in Python, you can use libraries like PyTorch or TensorFlow to define layers, loss functions, and optimization algorithms.",
            "Cooking recipes often require precise measurements and cooking times.",
            "Backpropagation is the algorithm used to train neural networks by computing gradients."
        ]

        instruction = "Given a technical query and a document, determine if the document provides practical information for implementing the requested technical solution"

        ranked_docs = self.rerank_documents(query, documents, instruction)
        self._print_ranking(f"\n技术文档重排序结果(查询: {query}):", ranked_docs)
        return ranked_docs

    def test_comparison_with_official_example(self):
        """Score the query/document pairs of the official example one by one."""
        self._print_section("测试与官方示例的对比")

        task = self.DEFAULT_INSTRUCTION
        queries = [
            "What is the capital of China?",
            "Explain gravity",
        ]
        documents = [
            "The capital of China is Beijing.",
            "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
        ]

        print("测试官方示例的查询-文档对:")
        for pair_no, (query, doc) in enumerate(zip(queries, documents), start=1):
            print(f"\n=== 查询-文档对 {pair_no} ===")
            print(f"查询: {query}")
            print(f"文档: {doc}")

            score = self.get_reranker_score(task, query, doc)
            print(f"相关性分数: {score:.4f}")

    def cleanup(self):
        """Destroy the runtime if it exists; safe to call more than once."""
        if self.runtime is None:
            return
        try:
            self.runtime.destroy()
            print("模型资源已清理")
        except Exception as e:
            print(f"清理资源时发生错误: {e}")
        finally:
            # Never destroy the same runtime twice.
            self.runtime = None
|
|
|
|
|
|
|
def main():
    """Command-line entry point: validate paths, run all reranking tests, clean up.

    Returns:
        Process exit status: 0 when every test ran, 1 when the model or
        library file is missing or a test raised, 130 when interrupted.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Qwen3-Reranker-0.6B 推理测试')
    parser.add_argument('model_path', help='模型文件路径(.rkllm格式)')
    parser.add_argument('--library_path', default="./librkllmrt.so", help='RKLLM库文件路径(默认为./librkllmrt.so)')
    args = parser.parse_args()

    if not os.path.exists(args.model_path):
        print(f"错误: 模型文件不存在: {args.model_path}")
        print("请确保:")
        print("1. 已下载 Qwen3-Reranker-0.6B 模型")
        print("2. 已使用 rkllm-convert.py 将模型转换为 .rkllm 格式")
        return 1

    if not os.path.exists(args.library_path):
        print(f"错误: RKLLM 库文件不存在: {args.library_path}")
        print("请确保 librkllmrt.so 在当前目录或 LD_LIBRARY_PATH 中")
        return 1

    print("Qwen3-Reranker-0.6B 推理测试")
    print("=" * 60)
    print("基于官方 README 的正确实现")
    print("=" * 60)

    tester = Qwen3RerankerTester(args.model_path, args.library_path)

    try:
        tester.init_model()

        print("\n开始运行重排序测试...")

        tester.test_comparison_with_official_example()
        tester.test_basic_reranking()
        tester.test_multilingual_reranking()
        tester.test_domain_specific_reranking()

        print("\n" + "=" * 60)
        print("所有重排序测试完成!")
        print("=" * 60)
        return 0

    except KeyboardInterrupt:
        print("\n测试被用户中断")
        # Conventional status for termination by SIGINT.
        return 130

    except Exception as e:
        print(f"\n测试过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
        return 1

    finally:
        # Runs on every path above, including the early returns.
        tester.cleanup()


if __name__ == "__main__":
    # Propagate main()'s status so callers/CI can detect failures.
    raise SystemExit(main())