| from components.vector_db_operations import get_collection_from_vector_db | |
| from components.vector_db_operations import retrieval | |
| from components.english_information_extraction import english_information_extraction | |
| from components.multi_lingual_model import MDFEND , loading_model_and_tokenizer | |
| from components.data_loading import preparing_data , loading_data | |
| from components.language_identification import language_identification | |
# Defaults used by the non-English branch of run_pipeline.
_VECTOR_DB_PATH = "/content/drive/MyDrive/general_domains/vector_database"
_COLLECTION_NAME = "general_domains"
# Retrieval distances above this value are reported as an undetermined domain.
_MAX_DOMAIN_DISTANCE = 1.45
# Model outputs at or above this value are labelled "discriminative".
_DISCRIMINATION_THRESHOLD = 0.5


def run_pipeline(
    input_text: str,
    *,
    vector_db_path: str = _VECTOR_DB_PATH,
    collection_name: str = _COLLECTION_NAME,
    max_domain_distance: float = _MAX_DOMAIN_DISTANCE,
):
    """Route *input_text* to the English or the multilingual analysis path.

    The language is taken as the first key of the mapping returned by
    ``language_identification`` (presumably the most likely language --
    confirm against that helper).

    * English (``"en"``): the result of ``english_information_extraction``
      is returned unchanged.
    * Any other language: the closest domain is retrieved from the vector
      database, the multilingual model is run on the text, and a summary
      dict is returned.

    Args:
        input_text: The text to analyse.
        vector_db_path: Location of the vector database that holds the
            domain collection.
        collection_name: Name of the collection to query for the domain.
        max_domain_distance: Retrieval distances strictly greater than this
            are reported with the domain label ``"undetermined"``.

    Returns:
        For English input, whatever ``english_information_extraction``
        returns. Otherwise a dict with the keys ``"domain_label"``,
        ``"domain_score"`` (the retrieval distance),
        ``"discrimination_label"`` (``"discriminative"`` or
        ``"not discriminative"``) and ``"discrimination_score"`` (the raw
        model output the label was derived from).
    """
    language_dict = language_identification(input_text)
    language_code = next(iter(language_dict))
    if language_code == "en":
        return english_information_extraction(input_text)

    # Imported here because the module's import block does not bring torch
    # into scope; without it the no_grad() call below raises NameError.
    import torch

    collection = get_collection_from_vector_db(vector_db_path, collection_name)
    # Only the single nearest match is requested.
    domain, label_domain, distance = retrieval(input_text, 1, collection)
    if distance > max_domain_distance:
        # Only the reported label changes; label_domain is still what the
        # model receives below, exactly as before.
        domain = "undetermined"

    tokenizer, model = loading_model_and_tokenizer()
    df = preparing_data(input_text, label_domain)
    input_ids, input_masks, input_domains = loading_data(tokenizer, df)

    with torch.no_grad():
        pred = model.forward(input_ids, input_masks, input_domains)

    scores = [output.item() for output in pred]
    # Label and score are both read from the last prediction. The previous
    # code sliced pred[1:] for the score, which only works when the model
    # returns exactly two elements.
    score = scores[-1]
    if score >= _DISCRIMINATION_THRESHOLD:
        discrimination_label = "discriminative"
    else:
        discrimination_label = "not discriminative"

    return {
        "domain_label": domain,
        "domain_score": distance,
        "discrimination_label": discrimination_label,
        "discrimination_score": score,
    }