merve HF Staff committed on
Commit
73e457a
·
verified ·
1 Parent(s): 8682446

Upload train_qwen3_vl.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_qwen3_vl.py +79 -0
train_qwen3_vl.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# /// script
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "trackio",
#     "transformers>=4.57.0",
#     "torch",
#     "datasets",
#     "pillow",
#     "qwen-vl-utils"
# ]
# ///
"""LoRA fine-tuning of a Qwen3-VL checkpoint on a 1% slice of llava-instruct-mix.

The script loads the first 1% of the ``train`` split, holds out 10% of that
slice for evaluation, trains LoRA adapters with TRL's ``SFTTrainer``, logs to
trackio, and pushes checkpoints plus the final model to the Hugging Face Hub.

The inline metadata block above lets a PEP 723-aware runner install the
dependencies before executing the script.
"""

# NOTE(review): the transformers floor was raised to the first release that
# ships Qwen3-VL. The trl floor is left as found, but the `max_length` keyword
# used below is presumably not accepted by releases that old - TODO confirm
# and raise the floor accordingly.

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
import trackio
import torch

# NOTE(review): confirm this repo id exists on the Hub before launching; the
# published Qwen3-VL dense sizes appear to be 2B/4B/8B/32B - TODO verify.
MODEL_ID = "Qwen/Qwen3-VL-3B-Instruct"
DATASET_ID = "trl-lib/llava-instruct-mix"
OUTPUT_DIR = "qwen3-vl-3b-llava-instruct"
HUB_MODEL_ID = "merve/qwen3-vl-3b-llava-instruct"


def main() -> None:
    """Load the data, run LoRA supervised fine-tuning, and push to the Hub."""
    # Load 1% of the train split
    print("Loading dataset...")
    dataset = load_dataset(DATASET_ID, split="train[:1%]")

    print(f"Dataset size: {len(dataset)} examples")

    # Create a small eval split (10% of the 1%); the fixed seed keeps the
    # train/eval partition reproducible across runs.
    dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset = dataset_split["train"]
    eval_dataset = dataset_split["test"]

    print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")

    # LoRA adapters are attached to the q/k/v/o projection modules only.
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        num_train_epochs=3,
        # Batch of 1 with 8 accumulation steps: 8 examples per optimizer
        # step on each device.
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        gradient_checkpointing=True,
        learning_rate=2e-4,
        warmup_steps=100,
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=2,
        bf16=True,
        report_to="trackio",
        project="qwen3-vl-finetuning",
        run_name="qwen3-vl-3b-llava-1pct",
        max_length=None,  # Important for VL models - don't truncate image tokens
        hub_strategy="every_save",
        remove_unused_columns=False,  # Keep all columns for VL processing
    )

    # Configure trainer with VL-specific settings; the model is passed as a
    # Hub repo id string rather than a pre-loaded model object.
    trainer = SFTTrainer(
        model=MODEL_ID,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=peft_config,
        args=training_args,
    )

    print("Starting training...")
    trainer.train()

    print("Pushing final model to Hub...")
    trainer.push_to_hub()

    print("Training complete!")


if __name__ == "__main__":
    main()