import os

import gradio as gr
import torch
from transformers import BlipForConditionalGeneration, BlipProcessor, GenerationConfig

print(torch.__version__)  # log the PyTorch version for easier debugging

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_MODEL_PATH = 'IDEA-CCNL/Taiyi-BLIP-750M-Chinese'
HF_TOKEN = os.getenv('HF_TOKEN')  # optional Hugging Face access token
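# Optional: cache models locally and/or download through a mirror. The values
# below are illustrative; note that some huggingface_hub versions read
# HF_ENDPOINT at import time, so these may need to be set before the
# transformers import (or in the shell environment instead).
# os.environ['HF_HOME'] = 'D:/AI/OCR/img2text/models'
# os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'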
processor = BlipProcessor.from_pretrained(_MODEL_PATH, use_auth_token=HF_TOKEN)
model = BlipForConditionalGeneration.from_pretrained(_MODEL_PATH, use_auth_token=HF_TOKEN).eval().to(device)
def inference(raw_image, model_n, strategy):
    if model_n != 'Image Captioning':
        # Only one task is supported; fail loudly instead of falling through
        # to an undefined `captions` variable.
        raise ValueError(f"Unsupported task: {model_n}")
    inputs = processor(raw_image, return_tensors="pt").to(device)
    with torch.no_grad():
        if strategy == "Beam search":
            # Beam search: keep the num_beams most probable partial sequences
            # at each step and extend them until generation ends; results are
            # more accurate but less diverse.
            config = GenerationConfig(
                do_sample=False,
                num_beams=3,
                max_length=50,
                min_length=5,
            )
        else:
            # Nucleus (top-p) sampling: keep only the smallest set of tokens
            # whose cumulative probability exceeds p, renormalize, and sample
            # from that set; results are more diverse.
            config = GenerationConfig(
                do_sample=True,
                top_p=0.8,
                max_length=50,
                min_length=5,
            )
        captions = model.generate(**inputs, generation_config=config)
    caption = processor.decode(captions[0], skip_special_tokens=True)
    caption = caption.replace(' ', '')  # strip spaces the tokenizer inserts between Chinese tokens
    print(caption)
    return caption
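# A minimal direct call for a quick smoke test, bypassing the UI (assumes a
# local demo.jpg exists; the path is illustrative):
# from PIL import Image
# print(inference(Image.open('demo.jpg'), 'Image Captioning', 'Beam search'))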
inputs = [
    gr.Image(type='pil', label="Upload Image"),
    # Task selector; image captioning is currently the only supported task.
    gr.Radio(choices=['Image Captioning'], value="Image Captioning", label="Task"),
    # Two decoding strategies: beam search is more accurate, nucleus sampling more diverse.
    gr.Radio(choices=['Beam search', 'Nucleus sampling'], value="Nucleus sampling", label="Caption Decoding Strategy"),
]
outputs = gr.Textbox(label="Output")
title = "图片描述生成器"  # "Image Caption Generator"
gradio_app = gr.Interface(
    inference,
    inputs,
    outputs,
    title=title,
    examples=[['demo.jpg', "Image Captioning", "Nucleus sampling"]],
)
if __name__ == "__main__":
    gradio_app.launch()
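# To serve on the LAN or get a temporary public URL instead, standard
# gradio launch() parameters can be passed (optional):
# gradio_app.launch(server_name="0.0.0.0", share=True)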