import os
from transformers import AutoModelForCausalLM, Qwen3ForCausalLM, AutoTokenizer, AutoConfig
from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
from warnings import warn


# Define a custom model that wraps a causal LM and adds a regression head
class CausalLMForRegression(nn.Module):
    config_class = Qwen3ForCausalLM.config_class
    base_model_prefix = "model"

    def __init__(self, base_model_name):
        super().__init__()
        # Load the causal LM with hidden states enabled
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name, output_hidden_states=True
        )
        self.base_model = base_model_name
        # Regression head: maps the pooled hidden state to a single scalar
        self.regression_head = nn.Linear(self.model.config.hidden_size, 1)
        print(f"Initializing difficulty scorer from scratch using {base_model_name} as a base!")
        self._keys_to_ignore_on_save = []

    def forward(self, input_ids, attention_mask=None, labels=None):
        # Flatten extra dimensions if present
        if input_ids.dim() == 3:
            # e.g. from (accum_steps, batch_size, seq_length) to (accum_steps * batch_size, seq_length)
            input_ids = input_ids.view(-1, input_ids.size(-1))
        if attention_mask is not None and attention_mask.dim() == 3:
            attention_mask = attention_mask.view(-1, attention_mask.size(-1))

        outputs = self.model(input_ids, attention_mask=attention_mask)
        hidden_states = outputs.hidden_states[-1]  # shape: (batch, seq_length, hidden_size)

        # Mean-pooling over non-padding tokens
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).expand_as(hidden_states).to(hidden_states.dtype)
            hidden_sum = torch.sum(hidden_states * mask, dim=1)
            lengths = mask.sum(dim=1)
            pooled = hidden_sum / lengths
        else:
            pooled = hidden_states.mean(dim=1)

        logits = self.regression_head(pooled).squeeze(-1)

        loss = None
        if labels is not None:
            loss_fn = nn.HuberLoss()  # alternatively nn.MSELoss()
            loss = loss_fn(logits, labels)

        return {"loss": loss, "logits": logits}

    def get_input_embeddings(self):
        # Delegate to the underlying causal LM's get_input_embeddings method.
        return self.model.get_input_embeddings()

    def save_pretrained(self, output_dir, safe_serialization=False):
        os.makedirs(output_dir, exist_ok=True)
        # Sanity-check that we are saving the entire model properly
        model_state_dict = self.model.state_dict()
        for key, value in model_state_dict.items():
            if value.numel() == 0:
                print(f"Warning: Tensor {key} has shape {value.shape}, which may be problematic.")
        # Save the backbone with proper weight-tie handling; `safe_serialization` is accepted for
        # API compatibility, but saving is forced to the PyTorch format here.
        self.model.save_pretrained(output_dir, safe_serialization=False)
        torch.save(self.regression_head.state_dict(), os.path.join(output_dir, "regression_head.bin"))

    def get_tokenizer(self):
        try:
            tokenizer = AutoTokenizer.from_pretrained(self.model.name_or_path)
            print(f"Loaded tokenizer from {self.model.name_or_path}")
        except Exception:
            tokenizer = AutoTokenizer.from_pretrained(self.base_model)
            print(f"Loaded tokenizer from {self.base_model}")
        return tokenizer

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        warn("The `from_pretrained` method is currently only implemented for models with a Qwen3 base.")
        cfg = kwargs.pop("config", None)
        if cfg is None:
            cfg = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        cfg.output_hidden_states = True
        if "trust_remote_code" in kwargs:
            _ = kwargs.pop("trust_remote_code")

        backbone = Qwen3ForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=cfg,
            trust_remote_code=False,
            **kwargs,
        )

        if os.path.isdir(pretrained_model_name_or_path):
            head_path = os.path.join(pretrained_model_name_or_path, "regression_head.bin")
        else:
            head_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="regression_head.bin",
                repo_type="model",
            )

        inst = cls.__new__(cls)
        nn.Module.__init__(inst)
        inst.model = backbone
        inst.regression_head = nn.Linear(cfg.hidden_size, 1)
        inst._keys_to_ignore_on_save = []
        inst.base_model = "Qwen/Qwen3-8B"

        if os.path.exists(head_path):
            inst.regression_head.load_state_dict(
                torch.load(head_path, map_location="cpu")
            )
        else:
            print("'regression_head.bin' not found – initialising randomly.")
        return inst

    @torch.no_grad()
    def generate(self, *args, **kwargs):
        """
        Wrapper that forwards all arguments to the underlying causal LM so that
        GenerationMixin-based helpers (sampling, beam search,
        prepare_inputs_for_generation, etc.) keep working.
        """
        return self.model.generate(*args, **kwargs)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """
        Same here: to be able to load the model with AutoModelForCausalLM,
        we have to forward this method.
        """
        return self.model.prepare_inputs_for_generation(*args, **kwargs)
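

# Minimal usage sketch, not part of the class itself. Assumptions: a small Qwen3
# checkpoint such as "Qwen/Qwen3-0.6B" is reachable, and "./difficulty_scorer" is a
# hypothetical output directory. It walks the intended round trip: construct the
# scorer, score a batch, save, and reload via the custom `from_pretrained`.
if __name__ == "__main__":
    scorer = CausalLMForRegression("Qwen/Qwen3-0.6B")  # assumed base checkpoint
    tokenizer = scorer.get_tokenizer()

    batch = tokenizer(
        ["Prove that the square root of 2 is irrational."],
        return_tensors="pt",
        padding=True,
    )
    out = scorer(batch["input_ids"], attention_mask=batch["attention_mask"])
    print("Predicted difficulty:", out["logits"])

    scorer.save_pretrained("./difficulty_scorer")  # hypothetical path
    reloaded = CausalLMForRegression.from_pretrained("./difficulty_scorer")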