"""Extractive question answering with a fine-tuned BERT checkpoint.

Loads ``BertForQuestionAnswering`` and its fast tokenizer from ``MODEL_DIR``,
answers questions against contexts of any length by scoring overlapping token
windows, and, when run as a script, answers a few demo questions before
starting an interactive prompt.
"""

import numpy as np
import torch
from transformers import BertForQuestionAnswering, BertTokenizerFast

# ── Config ───────────────────────────────────────────────────
MODEL_DIR = "model"
MAX_LENGTH = 384   # max tokens per window (question + context + special tokens)
DOC_STRIDE = 128   # stride passed to the tokenizer for overlapping context windows
N_BEST = 20        # number of top start / top end logits paired per window
MAX_ANS_LEN = 30   # longest accepted answer span, in tokens
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = BertTokenizerFast.from_pretrained(MODEL_DIR)
model = BertForQuestionAnswering.from_pretrained(MODEL_DIR).to(DEVICE)
model.eval()
print(f"✅ Model loaded on {DEVICE}")


def answer_question(question: str, context: str) -> dict:
    """Return the highest-scoring answer span for *question* inside *context*.

    The context is tokenized into overlapping windows (``MAX_LENGTH`` /
    ``DOC_STRIDE``) so long texts are fully covered. In every window the
    ``N_BEST`` most likely start tokens are paired with the ``N_BEST`` most
    likely end tokens. A pair is discarded when either end lies outside the
    context, when the end precedes the start, or when the span is longer than
    ``MAX_ANS_LEN`` tokens. The surviving pair with the largest
    ``start_logit + end_logit`` over all windows wins. On a tie, the earliest
    pair wins.

    Returns:
        A dict with ``"answer"`` (span text), ``"score"`` (logit sum rounded
        to 4 decimals), ``"start"`` and ``"end"`` (character offsets into
        *context*, end exclusive). When no valid span exists the dict is
        ``{"answer": "No answer found.", "score": -999, "start": -1,
        "end": -1}``.
    """
    inputs = tokenizer(
        question,
        context,
        max_length=MAX_LENGTH,
        truncation="only_second",  # split/truncate the context, never the question
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
        return_tensors="pt",
    )

    # Neither of these keys is a model input, so both must be removed before
    # the forward pass.
    offset_mapping = inputs.pop("offset_mapping").numpy()  # (n_chunks, seq_len, 2)
    # Maps each window to its source example; with a single (question, context)
    # pair it carries no information, so it is discarded.
    inputs.pop("overflow_to_sample_mapping")

    # sequence_ids() is a BatchEncoding method, so query it before the
    # encoding is turned into a plain dict of device tensors.
    sequence_ids = [inputs.sequence_ids(i) for i in range(len(inputs["input_ids"]))]
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
    start_logits = outputs.start_logits.cpu().numpy()  # (n_chunks, seq_len)
    end_logits = outputs.end_logits.cpu().numpy()

    best = None
    for chunk_start, chunk_end, offsets, seq_ids in zip(
        start_logits, end_logits, offset_mapping, sequence_ids
    ):
        # Token indices of the N_BEST largest logits, highest first.
        top_starts = np.argsort(chunk_start)[::-1][:N_BEST]
        top_ends = np.argsort(chunk_end)[::-1][:N_BEST]

        for s in top_starts:
            for e in top_ends:
                # Both ends must be context tokens (sequence id 1). This
                # rejects question tokens, special tokens and padding.
                if seq_ids[s] != 1 or seq_ids[e] != 1:
                    continue
                if e < s or e - s + 1 > MAX_ANS_LEN:
                    continue
                score = float(chunk_start[s] + chunk_end[e])
                # Strict ">" keeps the first maximum, matching max() over a list.
                if best is None or score > best["score"]:
                    start_char = int(offsets[s][0])
                    end_char = int(offsets[e][1])
                    best = {
                        "score": score,
                        "text": context[start_char:end_char],
                        "start": start_char,
                        "end": end_char,
                    }

    if best is None:
        return {"answer": "No answer found.", "score": -999, "start": -1, "end": -1}

    return {
        "answer": best["text"],
        "score": round(best["score"], 4),
        "start": best["start"],
        "end": best["end"],
    }


def ask(question: str, context: str) -> None:
    """Answer *question* against *context* and print a formatted report."""
    result = answer_question(question, context)
    print(f"❓ Question: {question}")
    print(f"💬 Answer : {result['answer']}")
    print(f"📊 Score : {result['score']}")
    print(f"📍 Position: Char {result['start']}–{result['end']}")
    print("-" * 60)


ctx1 = """
The Amazon rainforest, also known as Amazonia, is a moist broadleaf tropical rainforest in the Amazon biome that covers most of the Amazon basin of South America.
This basin encompasses 7,000,000 km² of which 5,500,000 km² are covered by the rainforest.
The majority of the forest is contained within Brazil, with 60% of the rainforest.
"""

ctx2 = """
The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France.
It was constructed from 1887 to 1889 as the centerpiece of the 1889 World's Fair.
The tower is 330 metres tall and is the tallest structure in Paris.
"""

# Presumably repeated so the context spans more than one tokenizer window and
# exercises the sliding-window path. TODO: confirm the intent.
ctx3 = """
Python is a high-level, general-purpose programming language.
Its design philosophy emphasizes code readability with the use of significant indentation.
Python is dynamically typed and garbage-collected.
It supports multiple programming paradigms, including structured, object-oriented and functional programming.
It was created by Guido van Rossum and first released in 1991.
Python consistently ranks as one of the most popular programming languages.
It is widely used in data science, machine learning, web development, and automation.
The Python Package Index (PyPI) hosts hundreds of thousands of third-party modules.
The standard library is very extensive, offering tools suited to many tasks.
""" * 3


def _interactive_loop() -> None:
    """Read one context, then answer questions until 'quit', EOF or Ctrl-C."""
    print("\n" + "=" * 60)
    print("🎮 Interactive mode – stop with 'quit'")
    print("=" * 60)
    try:
        # An empty context can never yield an answer, so keep asking for one.
        context_interactive = ""
        while not context_interactive:
            context_interactive = input("📄 Input context:\n> ").strip()

        while True:
            q = input("\n❓ Question (or type 'quit'): ").strip()
            if q.lower() == "quit":
                break
            if not q:
                continue
            ask(q, context_interactive)
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C (or a closed stdin) end the session like 'quit'
        # instead of dumping a traceback.
        print()
    print("👋 Bye.")


if __name__ == "__main__":
    ask("How much of the Amazon rainforest is in Brazil?", ctx1)
    ask("When was the Eiffel Tower built?", ctx2)
    ask("When was Python first released?", ctx3)
    _interactive_loop()