| import argparse |
| import math |
| import os |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.optim as optim |
| from torch.utils.data import DataLoader |
| import copy |
| from torch.optim.lr_scheduler import CosineAnnealingLR |
| from torch.cuda.amp import autocast, GradScaler |
| from datasets import load_dataset |
| from transformers import AutoTokenizer |
| from typing import List, Tuple |
| import sys |
# Module-wide default device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
def parse_args():
    """Build the command-line interface and return the parsed arguments.

    Options that the parser does not know are ignored (``parse_known_args``)
    instead of aborting. After parsing, every entry of ``custom_data_paths``
    is joined onto a base directory and converted to an absolute path. The
    base directory is the PyInstaller bundle directory when frozen, the
    directory of this file when ``__file__`` exists, and the current working
    directory otherwise.

    Returns:
        argparse.Namespace: The parsed options.
    """
    cli = argparse.ArgumentParser(description='Train or Inference with World Model and Tree of Thought.')
    add = cli.add_argument

    # Model and dataset selection
    add('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
    add('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
    add('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')

    # Training hyper-parameters
    add('--batch_size', type=int, default=4, help='Batch size')
    add('--num_epochs', type=int, default=3, help='Number of epochs')
    add('--max_length', type=int, default=128, help='Maximum sequence length')
    add('--mcts_iterations', type=int, default=3, help='Number of MCTS Iterations')
    add('--mcts_exploration_constant', type=float, default=1.414, help='Exploration constant for MCTS')
    add('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
    add('--learning_rate', type=float, default=1e-4, help='Learning rate')
    add('--weight_decay', type=float, default=1e-2, help='Weight decay')
    add('--alpha', type=float, default=0.1, help='Entropy regularization weight')
    add('--beta', type=float, default=0.1, help='Variance regularization weight')
    add('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
    add('--save_dir', type=str, default='./models', help='Directory to save the models')
    add('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')

    # Run mode
    add('--mode', type=str, choices=['train', 'inference'], default='train', help='Mode: train or inference')
    add('--inference_mode', type=str, choices=['world_model', 'without_world_model', 'world_model_tree_of_thought'], default='world_model_tree_of_thought', help='Inference mode')
    add('--query', type=str, default='', help='Input query for inference')
    add('--train_mode', type=str, choices=['world_model', 'language_model'], default='language_model', help='Train world model or language model only')
    add('--beam_size', type=int, default=5, help='Beam size for beam search')
    add('--n_tokens_predict', type=int, default=3, help='Number of tokens to predict at each step')
    add('--load_model', type=str, default=None,
        help='Path to load saved model. If not provided, a new model will be initialized.')

    # Custom data
    add('--use_custom_data', action='store_true', help='Use custom data for training')
    add('--custom_data_paths', nargs='+',
        default=[
            '/content/drive/MyDrive/lightbulb/knowledge_base.json',
            '/content/drive/MyDrive/lightbulb/rag_cache.json',
            '/content/drive/MyDrive/lightbulb/llm_training_data/llm_training_data.jsonl',
        ],
        help='Paths to custom data files (relative to the script location or current working directory)')

    parsed, _ignored = cli.parse_known_args()

    # Work out which directory relative custom-data paths are resolved against.
    if hasattr(sys, 'frozen') and hasattr(sys, '_MEIPASS'):
        # PyInstaller unpacks to a temp folder and records it in _MEIPASS.
        root = sys._MEIPASS
    elif '__file__' in globals():
        # Ordinary script execution.
        root = os.path.dirname(os.path.abspath(__file__))
    else:
        # Interactive session (e.g. Jupyter, Colab).
        root = os.getcwd()

    # os.path.join leaves already-absolute entries untouched.
    parsed.custom_data_paths = [
        os.path.abspath(os.path.join(root, entry)) for entry in parsed.custom_data_paths
    ]
    return parsed
|
|
| import json |
| import jsonlines |
|
|
def load_custom_data_from_files(file_paths):
    """Read records from ``.json`` and ``.jsonl`` files into one flat list.

    A ``.json`` file whose top-level value is a list contributes each of its
    elements; any other top-level value is appended as a single record. A
    ``.jsonl`` file contributes every object yielded by the ``jsonlines``
    reader. Files with any other extension are reported and skipped.

    Args:
        file_paths: Iterable of path strings.

    Returns:
        list: All loaded records, in file order.
    """
    custom_data = []
    for file_path in file_paths:
        if file_path.endswith('.json'):
            # JSON is UTF-8 by specification; do not depend on the platform's
            # default locale encoding.
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if isinstance(data, list):
                custom_data.extend(data)
            else:
                custom_data.append(data)
        elif file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                custom_data.extend(reader)
        else:
            # Previously ignored without any trace; say so explicitly.
            print(f"Skipping unsupported custom data file: {file_path}")
    return custom_data
|
|
def preprocess_custom_data(data_list):
    """Normalise raw records into the flat dicts consumed by the dataset.

    Each record may be a dict or a JSON string encoding one. Strings that are
    not valid JSON, and records that are not JSON objects (lists, numbers,
    ...), are reported and skipped instead of raising.

    Args:
        data_list: Iterable of dicts and/or JSON strings.

    Returns:
        list[dict]: One dict per accepted record with the keys ``text``,
        ``episode_reward``, ``loss``, ``cosine_similarity``,
        ``rag_performance`` and ``ranking_model_performance``. Missing
        numeric fields default to 0 and missing text fields to ''.
    """
    numeric_fields = (
        'episode_reward',
        'loss',
        'cosine_similarity',
        'rag_performance',
        'ranking_model_performance',
    )
    processed_data = []
    for item in data_list:
        if isinstance(item, str):
            try:
                item = json.loads(item)
            except json.JSONDecodeError:
                print(f"Failed to parse JSON: {item[:100]}...")  # Print first 100 chars for debugging
                continue  # Skip this item if it's not valid JSON

        # Only JSON objects carry the fields read below; anything else would
        # fail on .get().
        if not isinstance(item, dict):
            print(f"Skipping non-object record of type {type(item).__name__}")
            continue

        query = item.get('query', '')
        content = item.get('content', '')
        # A failed RAG generation is stored as this sentinel; treat it as empty.
        if content == "RAG response generation failed.":
            content = ""

        processed_item = {'text': f"Query: {query} Content: {content}"}
        for field in numeric_fields:
            processed_item[field] = item.get(field, 0)
        processed_data.append(processed_item)

    return processed_data
|
|
def load_custom_data(args, tokenizer, custom_data):
    """Turn raw custom records into train/eval ``DataLoader`` objects.

    Records are normalised with ``preprocess_custom_data``, tokenised lazily
    per item (padded/truncated to ``args.max_length``) and split 80/20 at
    random into train and eval subsets.

    Args:
        args: Namespace providing ``max_length`` and ``batch_size``.
        tokenizer: Tokenizer exposing ``encode_plus``, ``pad_token`` and
            ``eos_token``.
        custom_data: Raw records accepted by ``preprocess_custom_data``.

    Returns:
        Tuple[DataLoader, DataLoader]: ``(train_loader, eval_loader)``.
    """
    processed_data = preprocess_custom_data(custom_data)

    # padding='max_length' below needs a pad token. Fall back to EOS when the
    # tokenizer has none, exactly as load_data does.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    class CustomDataset(torch.utils.data.Dataset):
        """Tokenises one preprocessed record per ``__getitem__`` call."""

        def __init__(self, data, tokenizer, max_length):
            self.data = data
            self.tokenizer = tokenizer
            self.max_length = max_length

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            item = self.data[idx]
            encoded = self.tokenizer.encode_plus(
                item['text'],
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            return {
                # squeeze(0) removes only the leading batch axis added by
                # return_tensors='pt'; a bare squeeze() would also drop the
                # sequence axis when max_length == 1.
                'input_ids': encoded['input_ids'].squeeze(0),
                'attention_mask': encoded['attention_mask'].squeeze(0),
                'episode_reward': torch.tensor(item['episode_reward'], dtype=torch.float),
                'loss': torch.tensor(item['loss'], dtype=torch.float),
                'cosine_similarity': torch.tensor(item['cosine_similarity'], dtype=torch.float),
                'rag_performance': torch.tensor(item['rag_performance'], dtype=torch.float),
                'ranking_model_performance': torch.tensor(item['ranking_model_performance'], dtype=torch.float)
            }

    dataset = CustomDataset(processed_data, tokenizer, args.max_length)

    # 80/20 random split; the eval side takes the rounding remainder.
    train_size = int(0.8 * len(dataset))
    eval_size = len(dataset) - train_size
    train_dataset, eval_dataset = torch.utils.data.random_split(dataset, [train_size, eval_size])

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4
    )
    eval_loader = DataLoader(
        eval_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4
    )

    return train_loader, eval_loader
|
|
|
|
|
|
def load_data(args, tokenizer):
    """Load a HuggingFace dataset and build language-modelling DataLoaders.

    The ``text`` column is tokenised, all token lists of a batch are
    concatenated, and the result is cut into fixed blocks of
    ``args.max_length`` tokens. ``labels`` is a copy of ``input_ids``.

    Args:
        args: Namespace providing ``dataset_name``, ``dataset_config``,
            ``max_length`` and ``batch_size``.
        tokenizer: Callable tokenizer; its ``pad_token`` is set to
            ``eos_token`` when it has none.

    Returns:
        Tuple[DataLoader, DataLoader]: ``(train_loader, eval_loader)``. The
        eval loader uses the ``validation`` split when present, otherwise
        ``test``.
    """
    from itertools import chain

    dataset = load_dataset(args.dataset_name, args.dataset_config)

    # Ensure the tokenizer has a padding token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        return tokenizer(examples['text'], truncation=True, max_length=args.max_length)

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=4,
        remove_columns=dataset['train'].column_names,
    )

    # Build inputs and labels for language modeling
    block_size = args.max_length

    def group_texts(examples):
        # Concatenate all token lists of the batch. chain.from_iterable is
        # linear, whereas sum(list_of_lists, []) copies quadratically.
        concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples['input_ids'])
        # Drop the small remainder so every block has exactly block_size tokens.
        total_length = (total_length // block_size) * block_size
        # Split by chunks of block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result['labels'] = result['input_ids'].copy()
        return result

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=4,
    )

    train_dataset = lm_datasets['train']
    eval_dataset = lm_datasets['validation'] if 'validation' in lm_datasets else lm_datasets['test']

    def data_collator(data):
        # Every block has the same length, so plain stacking is enough.
        return {
            'input_ids': torch.tensor([f['input_ids'] for f in data], dtype=torch.long),
            'labels': torch.tensor([f['labels'] for f in data], dtype=torch.long)
        }

    train_loader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=args.batch_size,
        collate_fn=data_collator,
        pin_memory=True,  # Speeds up transfer to GPU
        num_workers=4
    )
    eval_loader = DataLoader(
        eval_dataset,
        shuffle=False,
        batch_size=args.batch_size,
        collate_fn=data_collator,
        pin_memory=True,
        num_workers=4
    )

    return train_loader, eval_loader
|
|
def save_all_models(transformer_model, representation_network, dynamics_network, prediction_network, action_encoder, save_dir, epoch):
    """
    Write the state dict of every network to ``save_dir``.

    One file is written per network, named ``<component>_epoch_<epoch>.pt``.
    The directory is created if it does not exist.

    Args:
        transformer_model (nn.Module): Transformer model.
        representation_network (nn.Module): Representation network.
        dynamics_network (nn.Module): Dynamics network.
        prediction_network (nn.Module): Prediction network.
        action_encoder (nn.Module): Action encoder.
        save_dir (str): Directory to save the models.
        epoch (int): Current epoch number, embedded in each file name.
    """
    os.makedirs(save_dir, exist_ok=True)

    components = (
        ('transformer_model', transformer_model),
        ('representation_network', representation_network),
        ('dynamics_network', dynamics_network),
        ('prediction_network', prediction_network),
        ('action_encoder', action_encoder),
    )
    for prefix, module in components:
        destination = os.path.join(save_dir, f'{prefix}_epoch_{epoch}.pt')
        torch.save(module.state_dict(), destination)

    print(f"All models saved for epoch {epoch}.")
|
|
class RotaryPositionalEncoding(nn.Module):
    """Rotate adjacent feature pairs by a position-dependent angle.

    The input is laid out as ``(seq_len, batch, d_model)``. Features are
    paired as (even index, odd index) and each pair at position ``p`` is
    rotated by ``p * inv_freq[k]``.
    """

    def __init__(self, d_model):
        super().__init__()
        exponents = torch.arange(0, d_model, 2).float() / d_model
        # Registered as a buffer so it moves with the module but is not trained.
        self.register_buffer('inv_freq', 1.0 / (10000 ** exponents))

    def forward(self, x):
        # Unpacking three values also rejects inputs that are not 3-D.
        seq_len, _batch, _dim = x.size()
        positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq)  # (seq_len, d_model/2)
        # Insert a batch axis so the tables broadcast over the batch.
        sin = angles.sin().unsqueeze(1)
        cos = angles.cos().unsqueeze(1)

        even = x[..., 0::2]
        odd = x[..., 1::2]

        # Write the rotated pairs back into their interleaved slots.
        rotated = torch.zeros_like(x)
        rotated[..., 0::2] = even * cos - odd * sin
        rotated[..., 1::2] = even * sin + odd * cos
        return rotated
|
|
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention over ``num_heads`` parallel heads.

    Inputs and outputs are ``(batch, seq_len, d_model)``. Each head works on
    a ``d_model // num_heads`` slice of the projected features.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        # Per-head width. It was previously set to d_model, which made the
        # head reshape mix tokens together and left linear_out with an input
        # of the wrong width whenever num_heads > 1.
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, seq, d_model) into (batch, heads, seq, d_k)."""
        return projected.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query, key, value (torch.Tensor): Shape (batch_size, seq_len, d_model).
            mask (torch.Tensor, optional): Broadcastable to
                (batch_size, num_heads, q_len, k_len); positions where the
                mask equals 0 are excluded from attention.

        Returns:
            torch.Tensor: Attended values, shape (batch_size, q_len, d_model).
        """
        batch_size = query.size(0)
        query = self._split_heads(self.linear_q(query), batch_size)
        key = self._split_heads(self.linear_k(key), batch_size)
        value = self._split_heads(self.linear_v(value), batch_size)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # A large finite negative (not -inf) stays representable in fp16.
            scores = scores.masked_fill(mask == 0, -1e4)
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, value)

        # Merge the heads back into a single d_model-wide feature axis.
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.linear_out(output)
|
|
class MoE(nn.Module):
    """Token-level mixture-of-experts feed-forward layer with top-k routing.

    A linear gate scores every expert for every token. The ``top_k`` best
    scores are softmax-normalised and used to weight the outputs of the
    corresponding experts. Experts alternate between GELU (even index) and
    SiLU (odd index) activations.
    """

    def __init__(self, d_model, num_experts, d_ff, top_k=2, dropout=0.1):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        def build_expert(index):
            activation = nn.GELU() if index % 2 == 0 else nn.SiLU()
            return nn.Sequential(
                nn.Linear(d_model, d_ff),
                activation,
                nn.Linear(d_ff, d_model),
            )

        self.experts = nn.ModuleList([build_expert(i) for i in range(num_experts)])
        self.gate = nn.Linear(d_model, num_experts)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Route each token of ``x`` (batch, seq_len, d_model) through its top-k experts."""
        batch_size, seq_len, d_model = x.size()

        # Routing: pick the k best experts per token and normalise their scores.
        gate_logits = self.gate(x)
        weights, chosen = torch.topk(gate_logits, self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)

        # Work on a flat (tokens, features) layout.
        tokens = x.view(-1, d_model)
        mixed = torch.zeros_like(x).view(-1, d_model)
        chosen = chosen.view(-1, self.top_k)
        weights = weights.view(-1, self.top_k)

        for slot in range(self.top_k):
            slot_expert = chosen[:, slot]
            slot_weight = weights[:, slot]
            for expert_id, expert in enumerate(self.experts):
                routed = slot_expert == expert_id
                if not routed.any():
                    continue
                # Run the expert only on the tokens routed to it in this slot.
                contribution = slot_weight[routed].unsqueeze(-1) * expert(tokens[routed])
                mixed[routed] += contribution

        return self.dropout(mixed.view(batch_size, seq_len, d_model))
|
|
class TransformerBlock(nn.Module):
    """Post-norm transformer layer: self-attention, optional cross-attention, MoE.

    Each sub-layer is wrapped in a residual connection followed by LayerNorm.
    The cross-attention sub-layer only runs when an encoder output is given.
    """

    def __init__(self, d_model, num_heads, d_ff, num_experts, dropout=0.1, top_k=2):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.cross_attention = MultiHeadAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.moe = MoE(d_model, num_experts, d_ff, top_k, dropout)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None, enc_output=None, enc_mask=None):
        # 1) Self-attention over the layer input.
        hidden = self.norm1(x + self.self_attention(x, x, x, mask))

        # 2) Cross-attention against the encoder output (decoder use only).
        if enc_output is not None:
            attended = self.cross_attention(hidden, enc_output, enc_output, enc_mask)
            hidden = self.norm2(hidden + attended)

        # 3) Mixture-of-experts feed-forward.
        return self.norm3(hidden + self.moe(hidden))
|
|
class Transformer(nn.Module):
    """Encoder-decoder transformer built from ``TransformerBlock`` layers.

    Source and target share one embedding table. Embeddings are scaled by
    ``sqrt(d_model)`` and passed through the rotary positional encoding
    before entering the encoder / decoder stacks.
    """

    def __init__(self, input_dim, d_model, num_heads, num_layers, d_ff, num_experts, output_dim, dropout=0.1, top_k=2):
        super(Transformer, self).__init__()
        # NOTE(review): the last vocabulary id is used as the padding index;
        # confirm that this matches the tokenizer's actual pad id.
        self.embedding = nn.Embedding(input_dim, d_model, padding_idx=input_dim - 1)
        self.rotary_positional_encoding = RotaryPositionalEncoding(d_model)
        self.encoder_layers = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, d_ff, num_experts, dropout, top_k) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, d_ff, num_experts, dropout, top_k) for _ in range(num_layers)]
        )
        self.output_layer = nn.Linear(d_model, output_dim)
        self.d_model = d_model

    def _embed(self, token_ids):
        """Embed token ids, scale by sqrt(d_model) and apply the rotary encoding."""
        embedded = self.embedding(token_ids) * math.sqrt(self.d_model)
        # The rotary module works on (seq_len, batch, d_model), so transpose
        # into that layout and back again.
        embedded = self.rotary_positional_encoding(embedded.transpose(0, 1))
        return embedded.transpose(0, 1)

    def _decode(self, tgt, memory, tgt_mask=None, src_mask=None):
        """Run the decoder stack over ``tgt`` while attending to ``memory``."""
        hidden = self._embed(tgt)
        for layer in self.decoder_layers:
            hidden = layer(hidden, tgt_mask, memory, src_mask)
        return hidden

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        Args:
            src (torch.Tensor): Source token ids, shape (batch_size, src_len).
            tgt (torch.Tensor): Target token ids, shape (batch_size, tgt_len).
            src_mask (torch.Tensor, optional): Mask for encoder self-attention
                and decoder cross-attention.
            tgt_mask (torch.Tensor, optional): Mask for decoder self-attention.

        Returns:
            torch.Tensor: Logits of shape (batch_size, tgt_len, output_dim).
        """
        memory = self.encode(src, src_mask)
        return self.output_layer(self._decode(tgt, memory, tgt_mask, src_mask))

    def generate_with_beam_search(self, src, tokenizer, beam_size=5, max_length=20, n_tokens_predict=3, temperature=1.0):
        """
        Generate sequences using beam search with multi-token prediction.

        Each step appends ``n_tokens_predict`` tokens, so the loop runs
        ``max_length // n_tokens_predict`` times. Candidates are ranked by
        ``log_prob - 0.1 * cumulative_entropy + 0.05 * cumulative_variance``.
        The end-of-sequence checks call ``.item()`` on the last token of a
        hypothesis, so ``src`` must hold exactly one sequence.

        Args:
            src (torch.Tensor): Source input tensor of shape (1, seq_len)
            tokenizer: Tokenizer providing ``bos_token_id`` and ``eos_token_id``
            beam_size (int): Size of the beam for beam search
            max_length (int): Maximum length of the generated sequence
            n_tokens_predict (int): Number of tokens to predict at each step
            temperature (float): Temperature parameter for softmax

        Returns:
            List[Tuple[torch.Tensor, float]]: List of (sequence, score) tuples;
            the score of an extended hypothesis is a scalar tensor.
        """
        batch_size = src.size(0)
        device = src.device
        eos_id = tokenizer.eos_token_id

        # Encode the source once; every hypothesis shares it.
        src_enc = self.encode(src)

        # Each entry: (sequence, log probability, cumulative entropy, cumulative variance)
        beam = [(torch.full((batch_size, 1), tokenizer.bos_token_id, dtype=torch.long, device=device),
                 0.0,
                 torch.zeros(batch_size, device=device),
                 torch.zeros(batch_size, device=device))]

        positions = list(range(n_tokens_predict))

        for _ in range(max_length // n_tokens_predict):
            all_candidates = []
            for seq, score, cum_entropy, cum_variance in beam:
                # Finished hypotheses are carried over unchanged.
                if seq[:, -1].item() == eos_id:
                    all_candidates.append((seq, score, cum_entropy, cum_variance))
                    continue

                # Predict next n tokens: (batch, n_tokens_predict, vocab)
                logits = self.predict_next_n_tokens(src_enc, seq, n_tokens_predict)

                probs = F.softmax(logits / temperature, dim=-1)
                entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1)  # (batch, n_tokens_predict)
                variance = torch.var(probs, dim=-1)  # (batch, n_tokens_predict)

                # Entropy and variance describe the whole distribution at each
                # position, not the chosen token, so one step adds their sum
                # over positions. The previous code indexed the position axis
                # with top-k ranks and failed once beam_size exceeded
                # n_tokens_predict.
                step_entropy = entropy.sum(dim=-1)
                step_variance = variance.sum(dim=-1)

                # Top-k tokens for each position
                topk_probs, topk_indices = torch.topk(probs, k=beam_size, dim=-1)
                topk_log_probs = torch.log(topk_probs)

                # Enumerate every combination of one top-k rank per position.
                for i in range(beam_size ** n_tokens_predict):
                    ranks = [i // (beam_size ** j) % beam_size for j in positions]
                    new_tokens = topk_indices[:, positions, ranks]
                    new_seq = torch.cat([seq, new_tokens], dim=-1)
                    new_score = score + torch.sum(topk_log_probs[:, positions, ranks])
                    new_entropy = cum_entropy + step_entropy
                    new_variance = cum_variance + step_variance

                    all_candidates.append((new_seq, new_score, new_entropy, new_variance))

            # Select top beam_size candidates
            beam = sorted(all_candidates, key=lambda x: x[1] - 0.1 * x[2] + 0.05 * x[3], reverse=True)[:beam_size]

            # Stop if all beams have ended
            if all(seq[:, -1].item() == eos_id for seq, _, _, _ in beam):
                break

        return [(seq, score) for seq, score, _, _ in beam]

    def encode(self, src, src_mask=None):
        """Embed ``src`` and run it through the encoder stack.

        Args:
            src (torch.Tensor): Source token ids, shape (batch_size, seq_len).
            src_mask (torch.Tensor, optional): Encoder self-attention mask.

        Returns:
            torch.Tensor: Encoder output, shape (batch_size, seq_len, d_model).
        """
        src_enc = self._embed(src)
        for layer in self.encoder_layers:
            src_enc = layer(src_enc, src_mask)
        return src_enc

    def predict_next_n_tokens(self, src_enc, tgt_seq, n_tokens):
        """Return logits for the next ``n_tokens`` positions.

        The decoder is run once without masks; the logits at the last target
        position are repeated ``n_tokens`` times, so every predicted position
        shares the same distribution.

        Returns:
            torch.Tensor: Shape (batch_size, n_tokens, output_dim).
        """
        tgt_dec = self._decode(tgt_seq, src_enc, None, None)
        output = self.output_layer(tgt_dec[:, -1:])
        return output.repeat(1, n_tokens, 1)
|
|
| # Objective Functions |
|
|
class InfoNCE_Loss(nn.Module):
    """Contrastive InfoNCE loss between two views of the same batch."""

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, z_i, z_j):
        """
        Args:
            z_i (torch.Tensor): Representations from view i, shape (n, embed_dim)
            z_j (torch.Tensor): Representations from view j, shape (n, embed_dim)

        Returns:
            torch.Tensor: InfoNCE loss
        """
        n = z_i.size(0)

        # Stack both views and compare all rows via cosine similarity.
        stacked = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)  # (2n, embed_dim)
        logits = torch.matmul(stacked, stacked.T)  # (2n, 2n)

        # A row must never pick itself as its positive.
        self_pairs = torch.eye(2 * n, device=stacked.device, dtype=torch.bool)
        logits = logits.masked_fill(self_pairs, -1e4)

        # Row k of view i pairs with row k of view j, and vice versa.
        index = torch.arange(n, device=stacked.device)
        targets = torch.cat([index + n, index], dim=0)  # (2n,)

        logits = logits / self.temperature
        return self.cross_entropy(logits, targets)
|
|
class CovarianceRegularization(nn.Module):
    """Penalise off-diagonal entries of the embedding covariance matrix."""

    def __init__(self, lambda_reg=1e-3):
        super(CovarianceRegularization, self).__init__()
        self.lambda_reg = lambda_reg

    def forward(self, embeddings):
        """
        Args:
            embeddings (torch.Tensor): Embedding tensor, shape (batch_size, embed_dim)

        Returns:
            torch.Tensor: Covariance regularization loss (0 for a batch of one)
        """
        batch_size, _ = embeddings.size()
        centered = embeddings - embeddings.mean(dim=0)
        # Unbiased estimate. With a single sample the centred matrix is all
        # zeros, so clamping the divisor to 1 yields 0 instead of 0/0 = NaN.
        cov = (centered.T @ centered) / max(batch_size - 1, 1)
        # Sum of squared off-diagonal entries = all squares minus the diagonal.
        cov_loss = torch.sum(cov ** 2) - torch.sum(torch.diag(cov) ** 2)
        return self.lambda_reg * cov_loss
|
|
class DynamicsPerformanceLoss(nn.Module):
    """MSE between predicted and true next state plus a variance term."""

    def __init__(self, lambda_var=1e-3):
        super(DynamicsPerformanceLoss, self).__init__()
        self.lambda_var = lambda_var

    def forward(self, true_next_state, predicted_next_state):
        """
        Args:
            true_next_state (torch.Tensor): Ground truth next state, shape (batch_size, state_dim)
            predicted_next_state (torch.Tensor): Predicted next state, shape (batch_size, state_dim)

        Returns:
            torch.Tensor: Dynamics performance loss
        """
        mse_loss = F.mse_loss(predicted_next_state, true_next_state)
        if predicted_next_state.size(0) > 1:
            variance_loss = torch.var(predicted_next_state, dim=0).mean()
        else:
            # The unbiased variance of a single sample is NaN; a batch of one
            # (e.g. the last partial batch) contributes no variance term.
            variance_loss = predicted_next_state.new_zeros(())
        return mse_loss + self.lambda_var * variance_loss
|
|
class ThoughtConsistencyLoss(nn.Module):
    """Mean squared error between the true and a perturbed next state."""

    def __init__(self):
        super().__init__()

    def forward(self, true_next_state, perturbed_next_state):
        """
        Args:
            true_next_state (torch.Tensor): Ground truth next state, shape (batch_size, state_dim)
            perturbed_next_state (torch.Tensor): Perturbed next state, shape (batch_size, state_dim)

        Returns:
            torch.Tensor: Thought-consistency loss
        """
        consistency = F.mse_loss(true_next_state, perturbed_next_state)
        return consistency
|
|
class PolicyValueJointLoss(nn.Module):
    """Cross-entropy policy loss plus a weighted MSE value loss."""

    def __init__(self, lambda_value=0.5):
        super().__init__()
        self.lambda_value = lambda_value
        self.cross_entropy = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss()

    def forward(self, policy_logits, true_policy, value_pred, true_value):
        """
        Args:
            policy_logits (torch.Tensor): Logits from the policy network, shape (batch_size * seq_len, num_actions)
            true_policy (torch.Tensor): Ground truth policy, shape (batch_size * seq_len, num_actions)
            value_pred (torch.Tensor): Predicted values, shape (batch_size * seq_len)
            true_value (torch.Tensor): Ground truth values, shape (batch_size * seq_len)

        Returns:
            torch.Tensor: Combined policy and value loss
        """
        # Collapse any leading axes so each row is one decision.
        flat_logits = policy_logits.reshape(-1, policy_logits.size(-1))
        flat_policy = true_policy.reshape(-1, true_policy.size(-1))

        # The target class is the most probable action of the true policy.
        target_actions = flat_policy.argmax(dim=1)
        policy_term = self.cross_entropy(flat_logits, target_actions)

        value_term = self.mse_loss(value_pred.reshape(-1), true_value.reshape(-1))
        return policy_term + self.lambda_value * value_term
|
|
class ActionDiversityReward(nn.Module):
    """Penalise pairwise cosine similarity between action embeddings."""

    def __init__(self, lambda_div=1e-3):
        super().__init__()
        self.lambda_div = lambda_div

    def forward(self, action_embeddings):
        """
        Args:
            action_embeddings (torch.Tensor): Embeddings of actions, shape (batch_size, embed_dim)

        Returns:
            torch.Tensor: Action diversity loss
        """
        # All-pairs cosine similarity via broadcasting: (batch, batch).
        pairwise = F.cosine_similarity(
            action_embeddings.unsqueeze(1),
            action_embeddings.unsqueeze(0),
            dim=2,
        )
        # Subtract the identity to remove each embedding's similarity with itself.
        identity = torch.eye(pairwise.size(0)).to(action_embeddings.device)
        off_diagonal = pairwise - identity
        return self.lambda_div * torch.sum(off_diagonal ** 2)
|
|
class ExpectedThoughtValueLoss(nn.Module):
    """Negated mean of the best MCTS values, so minimising it raises the value."""

    def __init__(self):
        super().__init__()

    def forward(self, mcts_best_values):
        """
        Args:
            mcts_best_values (torch.Tensor): Best values from MCTS, shape (batch_size)

        Returns:
            torch.Tensor: ETV loss
        """
        expected_value = mcts_best_values.mean()
        return -expected_value
|
|
class ExplorationRegularization(nn.Module):
    """Regulariser built from inverse visit counts."""

    def __init__(self, lambda_expl=1e-3):
        super().__init__()
        self.lambda_expl = lambda_expl

    def forward(self, visit_counts):
        """
        Args:
            visit_counts (torch.Tensor): Visit counts for actions, shape (batch_size, num_actions)

        Returns:
            torch.Tensor: Exploration regularization loss
        """
        # The +1 keeps never-visited actions finite (they contribute 1 each).
        inverse_visits = 1.0 / (visit_counts + 1)
        per_sample = inverse_visits.sum(dim=-1)
        return self.lambda_expl * per_sample.mean()
|
|
class KL_DivergenceLoss(nn.Module):
    """KL divergence KL(old_policy || new_policy), averaged over the batch."""

    def __init__(self, eps=1e-9):
        """
        Args:
            eps (float): Lower bound applied to ``new_policy`` before taking
                its logarithm.
        """
        super(KL_DivergenceLoss, self).__init__()
        self.eps = eps

    def forward(self, old_policy, new_policy):
        """
        Args:
            old_policy (torch.Tensor): Old policy probabilities, shape (batch_size, num_actions)
            new_policy (torch.Tensor): New policy probabilities, shape (batch_size, num_actions)

        Returns:
            torch.Tensor: KL divergence loss
        """
        # log(0) = -inf turns the loss into inf, or NaN when the old policy is
        # also 0 there (0 * -inf). Clamping keeps the log finite.
        log_new_policy = new_policy.clamp_min(self.eps).log()
        # F.kl_div takes log-probabilities as input and probabilities as target.
        return F.kl_div(log_new_policy, old_policy, reduction='batchmean')
|
|
| # MuZero Components |
|
|
class ActionEncoder(nn.Module):
    """Embedding table that maps action ids to dense vectors."""

    def __init__(self, action_vocab_size, embed_dim):
        super().__init__()
        self.embedding = nn.Embedding(action_vocab_size, embed_dim)

    def forward(self, action_indices):
        """
        Args:
            action_indices (torch.Tensor): Tensor of shape (batch_size, seq_len)

        Returns:
            torch.Tensor: Encoded actions of shape (batch_size, seq_len, embed_dim)
        """
        encoded_actions = self.embedding(action_indices)
        return encoded_actions
|
|
class RepresentationNetwork(nn.Module):
    """Map transformer outputs to a normalised latent state.

    Two linear projections (vocab_dim -> d_model -> state_dim) are followed
    by LayerNorm over the state features.
    """

    def __init__(self, vocab_dim, d_model, state_dim):
        super().__init__()
        self.proj = nn.Linear(vocab_dim, d_model)    # vocab_dim -> d_model
        self.linear = nn.Linear(d_model, state_dim)  # d_model -> state_dim
        self.norm = nn.LayerNorm(state_dim)

    def forward(self, transformer_output):
        """
        Args:
            transformer_output (torch.Tensor): Shape (batch_size, seq_len, vocab_dim)

        Returns:
            torch.Tensor: Encoded state of shape (batch_size, seq_len, state_dim)
        """
        hidden = self.proj(transformer_output)  # (batch_size, seq_len, d_model)
        return self.norm(self.linear(hidden))   # (batch_size, seq_len, state_dim)
|
|
|
|
class DynamicsNetwork(nn.Module):
    """Predict the next latent state from the current state and an action.

    The state is layer-normalised, concatenated with the action embedding
    and passed through a two-layer GELU MLP.
    """

    def __init__(self, state_dim, action_dim, hidden_dim):
        super().__init__()
        # The attribute is called rms_norm but holds a LayerNorm.
        self.rms_norm = nn.LayerNorm(state_dim)
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, state_dim)

    def forward(self, state, action):
        """
        Args:
            state (torch.Tensor): Current state, shape (batch_size, seq_len, state_dim)
            action (torch.Tensor): Action embedding, shape (batch_size, seq_len, action_dim)

        Returns:
            torch.Tensor: Predicted next state, shape (batch_size, seq_len, state_dim)
        """
        # Join the normalised state and the action along the feature axis.
        features = torch.cat([self.rms_norm(state), action], dim=-1)
        hidden = self.fc1(features)
        hidden = self.activation(hidden)
        return self.fc2(hidden)
|
|
class PredictionNetwork(nn.Module):
    """Policy and value heads on top of a layer-normalised state."""

    def __init__(self, state_dim, action_vocab_size, value_dim):
        super().__init__()
        self.state_dim = state_dim
        # The attribute is called rms_norm but holds a LayerNorm.
        self.rms_norm = nn.LayerNorm(state_dim)
        self.policy_head = nn.Linear(state_dim, action_vocab_size)  # one logit per action
        self.value_head = nn.Linear(state_dim, value_dim)

    def forward(self, state):
        """
        Args:
            state (torch.Tensor): State representation, shape (batch_size, state_dim)
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Policy logits and value estimates
        """
        normalized = self.rms_norm(state)
        logits = self.policy_head(normalized)  # (batch_size, action_vocab_size)
        # squeeze(-1) drops the trailing axis only when value_dim is 1.
        values = self.value_head(normalized).squeeze(-1)
        return logits, values
|
|
|
|
|
|
|
|
class MCTSNode:
    """One node of the search tree, tied to a node of the thought tree.

    Children are keyed by action, which is the ``name`` of the corresponding
    child thought node.
    """

    __slots__ = [
        'state',
        'parent',
        'action',
        'children',
        'visit_count',
        'value_sum',
        'prior',
        'cached_policy',
        'cached_value',
        'thought_node',
        'entropy',
        'variance'
    ]

    def __init__(self, state, thought_node, parent=None, action=None):
        # Position in the tree
        self.state = state
        self.thought_node = thought_node
        self.parent = parent
        self.action = action
        self.children = {}
        # Search statistics
        self.visit_count = 0
        self.value_sum = 0.0
        self.prior = 0.0
        # Cached network outputs (not yet computed)
        self.cached_policy = None
        self.cached_value = None
        # Uncertainty terms used by ucb_score
        self.entropy = 0.0
        self.variance = 0.0

    def expand(self, priors):
        """Create a child for every child thought node that has none yet.

        Args:
            priors (dict): Maps action name to prior probability. Actions
                missing from it receive a uniform prior over the thought
                node's children.
        """
        for child_thought in self.thought_node.children:
            action = child_thought.name
            if action in self.children:
                continue  # already expanded
            node = MCTSNode(
                state=self.state.apply_action(action),
                thought_node=child_thought,
                parent=self,
                action=action,
            )
            uniform_prior = 1.0 / len(self.thought_node.children)
            node.prior = priors.get(action, uniform_prior)
            self.children[action] = node

    def is_leaf(self):
        """Return True when the node has no children."""
        return not self.children

    def ucb_score(self, total_visits, exploration_constant=math.sqrt(2)):
        """Selection score: mean value + prior-weighted exploration + uncertainty terms."""
        if self.visit_count == 0:
            # Unvisited nodes are always selected first.
            return float('inf')
        mean_value = self.value_sum / self.visit_count
        exploration = exploration_constant * self.prior * math.sqrt(total_visits) / (1 + self.visit_count)
        entropy_penalty = -0.1 * self.entropy  # slightly prefer lower entropy
        variance_bonus = 0.05 * self.variance  # slightly prefer higher variance
        return mean_value + exploration + entropy_penalty + variance_bonus
|
|
|
|
class MCTS:
    """Tree search over the thought tree, guided by the prediction network.

    Two entry points are offered: ``search`` (select / evaluate / backpropagate
    for ``num_iterations`` rounds) and ``search_with_beam`` (a beam of partial
    action sequences, each extended by short UCB-guided rollouts).
    """

    def __init__(self, prediction_network, dynamics_network, action_encoder, num_iterations=10, exploration_constant=math.sqrt(2), beam_size=5, n_tokens_predict=3):
        self.prediction_network = prediction_network
        self.dynamics_network = dynamics_network
        self.action_encoder = action_encoder
        self.num_iterations = num_iterations
        self.exploration_constant = exploration_constant
        self.beam_size = beam_size
        self.n_tokens_predict = n_tokens_predict
        self.cache = {}

    def _expand_if_leaf(self, node):
        """Evaluate ``node`` and back its value up, but only if it has no children yet."""
        if node.is_leaf():
            self.backpropagate(node, self.evaluate(node))

    def _ucb_key(self, node):
        """Build a sort key over ``node.children.items()`` based on UCB score."""
        sibling_visits = sum(child.visit_count for child in node.children.values())
        return lambda item: item[1].ucb_score(sibling_visits, self.exploration_constant)

    def _rollout(self, first_action, first_node, score, cum_entropy, cum_variance, prefix):
        """Extend a beam entry through ``first_node`` by up to ``n_tokens_predict`` more actions.

        Returns a candidate tuple
        ``(node, score, cum_entropy, cum_variance, action_sequence)``.
        """
        node = first_node
        path = prefix + [first_action]
        # Statistics of the first node are read before it is evaluated.
        entropy_total = cum_entropy + first_node.entropy
        variance_total = cum_variance + first_node.variance

        for _ in range(self.n_tokens_predict):
            self._expand_if_leaf(node)
            if not node.children:
                break  # No more actions
            action, child = max(node.children.items(), key=self._ucb_key(node))
            path.append(action)
            # An unvisited child contributes nothing (avoids dividing by zero).
            score += child.value_sum / child.visit_count if child.visit_count > 0 else 0.0
            entropy_total += child.entropy
            variance_total += child.variance
            node = child

        return node, score, entropy_total, variance_total, path

    def search_with_beam(self, root_state):
        """Run the beam variant of the search and return the best action sequence.

        Returns:
            list: Action names of the top-ranked beam entry (empty when the
            root cannot be expanded).
        """
        root = MCTSNode(state=root_state, thought_node=root_state.thought_node)
        # A freshly created root is always a leaf: evaluate, expand, back up.
        self.backpropagate(root, self.evaluate(root))

        # Entries: (node, score, cum_entropy, cum_variance, action_sequence)
        beam = [(root, 0.0, 0.0, 0.0, [])]

        for iteration in range(self.num_iterations):
            candidates = []
            for node, score, cum_entropy, cum_variance, actions in beam:
                self._expand_if_leaf(node)
                if not node.children:
                    continue  # No children to expand
                # Rank once, before any rollout changes the visit counts.
                ranked = sorted(node.children.items(), key=self._ucb_key(node), reverse=True)
                for first_action, first_node in ranked[:self.beam_size]:
                    candidates.append(
                        self._rollout(first_action, first_node, score, cum_entropy, cum_variance, actions)
                    )

            if not candidates:
                break  # No more candidates to expand

            candidates.sort(key=lambda c: c[1] - 0.1 * c[2] + 0.05 * c[3], reverse=True)
            beam = candidates[:self.beam_size]
            print(f"Iteration {iteration + 1}: Beam size after sorting: {len(beam)}")  # Debug

        return beam[0][4] if beam else []

    def search(self, root_state):
        """Run ``num_iterations`` select/evaluate/backpropagate rounds from ``root_state``."""
        root = MCTSNode(state=root_state, thought_node=root_state.thought_node)
        for _ in range(self.num_iterations):
            leaf = self.select(root)
            self.backpropagate(leaf, self.evaluate(leaf))
        return self.best_action_sequence(root)

    def select(self, node):
        """Descend from ``node`` along the highest-UCB child until a leaf is reached."""
        current = node
        while not current.is_leaf():
            _, current = max(current.children.items(), key=self._ucb_key(current))
        return current

    def evaluate(self, node):
        """Score ``node`` with the prediction network and expand it.

        Side effects: expands ``node`` with priors taken from the policy and
        stores the policy's entropy and variance on the node.

        Returns:
            float: The value estimate (requires a single-element value tensor).
        """
        # Only the most recent time step of the state history is evaluated.
        state_representation = node.state.representation[:, -1, :]
        print(f"Evaluating node with state_representation shape: {state_representation.shape}")  # Debug
        policy_logits, value_estimate = self.prediction_network(state_representation)
        print(f"Policy logits shape: {policy_logits.shape}, Value estimate shape: {value_estimate.shape}")  # Debug
        value_estimate = value_estimate.item()

        policy_probs = F.softmax(policy_logits, dim=-1).squeeze(0)
        print(f"Policy probabilities shape: {policy_probs.shape}")  # Debug

        # Prior per thought child: its policy probability when the action has
        # a usable index, otherwise a uniform share.
        thought_children = node.thought_node.children
        priors = {}
        for thought_child in thought_children:
            name = thought_child.name
            idx = action_to_index.get(name, None)
            usable = idx is not None and idx < policy_probs.size(0)
            priors[name] = policy_probs[idx].item() if usable else 1.0 / len(thought_children)

        node.expand(priors)

        node.entropy = (-torch.sum(policy_probs * torch.log(policy_probs + 1e-9))).item()
        node.variance = torch.var(policy_probs).item()

        print(f"Node entropy: {node.entropy}, variance: {node.variance}")  # Debug

        return value_estimate

    def backpropagate(self, node, value):
        """Add one visit and ``value`` to ``node`` and every ancestor up to the root."""
        cursor = node
        while cursor is not None:
            cursor.visit_count += 1
            cursor.value_sum += value
            cursor = cursor.parent

    def best_action_sequence(self, root_node):
        """Return the actions of the best root-to-descendant path.

        Paths are ranked by total visit count minus ``0.1 *`` total entropy
        plus ``0.05 *`` total variance; at most ``n_tokens_predict`` actions
        (root excluded) are returned.
        """
        paths = []
        self._generate_sequences(root_node, [], paths)

        def adjusted(path):
            visits = sum(n.visit_count for n in path)
            entropy = sum(n.entropy for n in path)
            variance = sum(n.variance for n in path)
            return visits - 0.1 * entropy + 0.05 * variance

        ranked = sorted(
            ((path, adjusted(path)) for path in paths),
            key=lambda pair: pair[1],
            reverse=True,
        )
        top_path = ranked[:self.beam_size][0][0]
        return [n.action for n in top_path[1:self.n_tokens_predict + 1]]  # Exclude root node

    def _generate_sequences(self, node, current_sequence, sequences):
        """Collect every path starting at ``node`` into ``sequences``.

        A path ends when it holds more than ``n_tokens_predict`` nodes or when
        its last node has no children. ``current_sequence`` is extended in place.
        """
        current_sequence.append(node)
        complete = len(current_sequence) > self.n_tokens_predict or not node.children
        if complete:
            sequences.append(current_sequence)
            return
        for child in node.children.values():
            # Each branch continues on its own copy of the path so far.
            self._generate_sequences(child, list(current_sequence), sequences)
|
|
class State:
    """A latent state history together with the current thought-tree node."""

    def __init__(self, representation, dynamics_network, action_encoder, thought_node):
        self.thought_node = thought_node
        self.representation = representation
        self.action_encoder = action_encoder
        self.dynamics_network = dynamics_network

    def apply_action(self, action):
        """Return the successor state reached by taking ``action``.

        Args:
            action: Name of one of the current thought node's children.

        Returns:
            State: New state whose representation is the old history with the
            predicted next state appended along dimension 1, positioned at the
            matching child thought node.

        Raises:
            ValueError: If no child of the current thought node is named ``action``.
        """
        target = next(
            (child for child in self.thought_node.children if child.name == action),
            None,
        )
        if target is None:
            raise ValueError(f"Action '{action}' is not valid from the current thought node.")

        history = self.representation

        # Encode the action as a one-element index batch on the history's device.
        action_index = torch.tensor([action_to_index[action]], device=history.device)
        action_embedding = self.action_encoder(action_index)

        # The dynamics network only sees the latest time step of the history.
        predicted = self.dynamics_network(history[:, -1, :], action_embedding)

        return State(
            representation=torch.cat([history, predicted.unsqueeze(1)], dim=1),
            dynamics_network=self.dynamics_network,
            action_encoder=self.action_encoder,
            thought_node=target,
        )
|
|
class PPOAgent:
    """Computes the clipped PPO objective for a policy/value network."""

    def __init__(self, policy_network, optimizer, clip_epsilon=0.2, entropy_coef=0.01, value_coef=0.5):
        self.policy_network = policy_network
        self.optimizer = optimizer
        self.clip_epsilon = clip_epsilon
        self.entropy_coef = entropy_coef
        self.value_coef = value_coef

    def compute_loss(self, states, old_log_probs, actions, returns, advantages):
        """Return the scalar PPO loss for a batch.

        Args:
            states: Input passed to ``policy_network``, which must return
                ``(policy_logits, value_estimates)``.
            old_log_probs (torch.Tensor): Log-probabilities of ``actions``
                under the behaviour policy.
            actions (torch.Tensor): Integer indices of the actions taken.
            returns (torch.Tensor): Regression targets for the value head.
            advantages (torch.Tensor): Advantage estimates.

        All tensors are flattened; after flattening they must share the same
        number of elements (one per decision).

        Returns:
            torch.Tensor: ``policy_loss + value_coef * value_loss
            - entropy_coef * entropy``.
        """
        policy_logits, value_estimates = self.policy_network(states)

        # Flatten everything to one row per decision.
        policy_logits = policy_logits.reshape(-1, policy_logits.size(-1))
        value_estimates = value_estimates.reshape(-1)
        actions = actions.reshape(-1)
        old_log_probs = old_log_probs.reshape(-1)
        returns = returns.reshape(-1)
        advantages = advantages.reshape(-1)

        assert policy_logits.size(0) == value_estimates.size(0) == actions.size(0) == old_log_probs.size(0) == returns.size(0) == advantages.size(0), "Tensor sizes mismatch"

        # Log-probabilities of the taken actions under the current policy.
        log_probs_all = F.log_softmax(policy_logits, dim=-1)
        new_log_probs = log_probs_all.gather(1, actions.unsqueeze(-1)).squeeze(-1)

        # Clipped surrogate objective.
        ratios = torch.exp(new_log_probs - old_log_probs)
        unclipped = ratios * advantages
        clipped = torch.clamp(ratios, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
        policy_loss = -torch.min(unclipped, clipped).mean()

        value_loss = F.mse_loss(value_estimates, returns)

        # Entropy of the full action distribution, averaged over the batch.
        # It must sum over every action; using only the taken action's
        # -p*log(p) term is not the distribution's entropy.
        entropy = -(log_probs_all.exp() * log_probs_all).sum(dim=-1).mean()

        return policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy
|
|
| # Tree of Thought Components |
|
|
class ThoughtNode:
    """Named node of the Tree of Thought with an ordered list of children."""

    def __init__(self, name):
        self.parent = None
        self.children = []
        self.name = name

    def add_child(self, child_node):
        """Append ``child_node`` to this node's children and set its parent link."""
        self.children.append(child_node)
        child_node.parent = self
|
|
| # Function to build the Tree of Thought from your detailed structure |
# Declarative description of the Tree of Thought.
# A node is either a plain string (a leaf) or a ``(name, [children...])`` pair.
# Child order is significant: it is the order in which children are attached.
_THOUGHT_TREE_SPEC = (
    'Problem-Solving Process', [
        ('Problem Identification', [
            ('Define the Problem', [
                'Problem Statement Formulation',
                'Scope Definition',
                'Objective Setting',
            ]),
            ('Identify Stakeholders', [
                'Stakeholder Mapping',
                'Interest and Influence Analysis',
                'Engagement Strategy',
            ]),
            ('Determine Constraints', [
                'Resource Limitations',
                'Time Constraints',
                'Legal and Regulatory Constraints',
            ]),
            ('Recognize Problem Type', [
                'Simple vs Complex',
                'Known vs Unknown',
                'Tame vs Wicked Problems',
            ]),
            ('Historical Context', [
                'Previous Attempts',
                'Lessons Learned',
                'Environmental Factors',
            ]),
        ]),
        ('Problem Analysis', [
            ('Root Cause Analysis', [
                '5 Whys Technique',
                'Fishbone Diagram',
                'Pareto Analysis',
            ]),
            ('System Mapping', [
                'Causal Loop Diagrams',
                'Stock and Flow Models',
                'Network Analysis',
            ]),
            ('Data Collection', [
                ('Quantitative Data', [
                    'Surveys and Questionnaires',
                    'Experimental Data',
                    'Big Data Analytics',
                ]),
                ('Qualitative Data', [
                    'Interviews',
                    'Focus Groups',
                    'Observational Studies',
                ]),
                ('Data Validation', [
                    'Statistical Validation',
                    'Cross-Validation',
                    'Expert Review',
                ]),
            ]),
            ('Impact Assessment', [
                'Environmental Impact',
                'Social Impact',
                'Economic Impact',
            ]),
            ('Theoretical Framework', [
                'Literature Review',
                'Conceptual Modeling',
                'Hypothesis Formation',
            ]),
        ]),
        ('Solution Generation', [
            'Creative Problem Solving',
            'Analytical Approach',
            'Mathematical Computation',
            'Decision Making',
        ]),
        ('Implementation', [
            'Action Planning',
            'Resource Allocation',
            'Change Management',
        ]),
        ('Evaluation and Adjustment', [
            'Verification',
            'Performance Metrics',
            'Feedback Loops',
            'Continuous Improvement',
        ]),
        ('Cross-Cutting Considerations', [
            ('Ethical Framework', [
                ('Value-based Decision Making', [
                    'Ethical Theories Application',
                    'Moral Dilemma Resolution',
                ]),
                ('Long-term Consequences', [
                    'Sustainability Assessment',
                    'Intergenerational Impact',
                ]),
            ]),
            ('Stakeholder Management', [
                'Direct Stakeholders',
                'Indirect Stakeholders',
                ('Conflicting Interests', [
                    'Negotiation Strategies',
                    'Conflict Resolution Techniques',
                ]),
            ]),
            ('Interdisciplinary Connections', [
                ('Related Fields', [
                    'Cross-domain Knowledge Transfer',
                    'Interdisciplinary Collaboration',
                ]),
                ('Cross-disciplinary Impact', [
                    'Synergy Identification',
                    'Holistic Impact Assessment',
                ]),
            ]),
            ('Technological Integration', [
                ('AI-assisted Problem Solving', [
                    'Machine Learning Models',
                    'Natural Language Processing',
                ]),
                ('Data-driven Insights', [
                    # Same name as a leaf under 'Quantitative Data'.
                    'Big Data Analytics',
                    'Predictive Modeling',
                ]),
                ('Digital Collaboration Tools', [
                    'Project Management Platforms',
                    'Virtual Reality Collaboration',
                ]),
            ]),
            ('Emotional Intelligence', [
                ('Self-Awareness', [
                    'Emotional Recognition',
                    'Personal Bias Identification',
                ]),
                ('Empathy', [
                    'Perspective Taking',
                    'Active Listening',
                ]),
                ('Stress Management', [
                    'Mindfulness Techniques',
                    'Resilience Building',
                ]),
            ]),
            ('Collaborative Problem Solving', [
                ('Team Dynamics', [
                    'Team Formation Strategies',
                    'Role Assignment',
                ]),
                ('Communication Strategies', [
                    'Clear Messaging',
                    'Feedback Mechanisms',
                ]),
                ('Conflict Resolution', [
                    'Mediation Techniques',
                    'Consensus Building',
                ]),
            ]),
            ('Computational Considerations', [
                ('CPU Operations', [
                    'Instruction Set Architecture',
                    'Pipelining and Parallelism',
                ]),
                ('GPU Parallelization', [
                    'CUDA Programming',
                    'OpenCL Framework',
                ]),
                ('Floating-Point Precision', [
                    'IEEE 754 Standard',
                    'Error Propagation Analysis',
                ]),
            ]),
            ('Order of Operations', [
                'Parentheses',
                'Exponents',
                'Multiplication and Division',
                'Addition and Subtraction',
            ]),
            ('Critical Thinking', [
                ('Assumptions Questioning', [
                    'Socratic Questioning',
                    "Devil's Advocate Approach",
                ]),
                ('Bias Recognition', [
                    'Cognitive Bias Identification',
                    'Debiasing Techniques',
                ]),
            ]),
            ('Future Perspective', [
                ('Short-term Projections', [
                    'Trend Analysis',
                    'Scenario Planning',
                ]),
                ('Long-term Scenarios', [
                    'Futures Wheel',
                    'Backcasting',
                ]),
                ('Potential Impacts', [
                    'Risk Assessment',
                    'Opportunity Identification',
                ]),
            ]),
            ('Learning and Adaptation', [
                ('Reflective Practice', [
                    'After Action Review',
                    'Learning Journals',
                ]),
                ('Knowledge Transfer', [
                    'Best Practice Documentation',
                    'Mentoring Programs',
                ]),
                ('Adaptive Problem Solving', [
                    'Iterative Approaches',
                    'Flexibility in Methodology',
                ]),
            ]),
        ]),
    ],
)


def _build_thought_subtree(spec):
    """Turn one spec entry (leaf string or ``(name, children)`` pair) into a ThoughtNode subtree."""
    if isinstance(spec, str):
        return ThoughtNode(spec)
    name, child_specs = spec
    node = ThoughtNode(name)
    for child_spec in child_specs:
        node.add_child(_build_thought_subtree(child_spec))
    return node


def build_tree_of_thought():
    """Build the Tree of Thought and return its root node.

    The hierarchy is described once in ``_THOUGHT_TREE_SPEC`` and
    materialised recursively, so adding or moving a thought is a one-line
    change to the spec. Each call builds a fresh, independent tree.

    Returns:
        ThoughtNode: Root node named ``'Problem-Solving Process'``; children
        are attached via ``add_child`` in the order listed in the spec.
    """
    return _build_thought_subtree(_THOUGHT_TREE_SPEC)
|
|
def traverse_tree(node, action_list):
    """Append every node name in the subtree at ``node`` to ``action_list``.

    Nodes are visited parent-first, children in their stored order. A name
    already present in ``action_list`` is not added again, but that node's
    children are still visited. ``action_list`` is modified in place.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current.name not in action_list:
            action_list.append(current.name)
        # Push children reversed so the first child is popped (visited) first.
        pending.extend(reversed(list(current.children)))
|
|
|
|
|
|
def _extend_hypothesis(hypothesis, prediction_network, beam_size, temperature):
    """Return up to ``beam_size`` one-action extensions of a beam hypothesis.

    A hypothesis is ``(state, score, cum_entropy, cum_variance, actions)``.
    Only actions that are children of the state's thought node and have an
    entry in ``action_to_index`` are considered, because those are the only
    ones ``State.apply_action`` can execute. An empty list means the
    hypothesis cannot be extended.
    """
    state, score, cum_entropy, cum_variance, actions = hypothesis
    children = state.thought_node.children
    if not children:
        return []

    # Policy over actions at the most recent time step of the state history.
    policy_logits, _ = prediction_network(state.representation[:, -1, :])
    probs = F.softmax(policy_logits / temperature, dim=-1).squeeze(0)
    entropy = (-torch.sum(probs * torch.log(probs + 1e-9))).item()
    variance = torch.var(probs).item()

    uniform = 1.0 / len(children)
    scored = []
    for child in children:
        idx = action_to_index.get(child.name)
        if idx is None:
            continue  # apply_action could not encode this action
        # Indices outside the policy's range fall back to a uniform share.
        prob = probs[idx].item() if idx < probs.size(0) else uniform
        scored.append((prob, child.name))
    scored.sort(key=lambda pair: pair[0], reverse=True)

    return [
        (
            state.apply_action(name),
            score + math.log(prob + 1e-9),
            cum_entropy + entropy,
            cum_variance + variance,
            actions + [name],
        )
        for prob, name in scored[:beam_size]
    ]


def _beam_search_thoughts(initial_state, prediction_network, beam_size, n_tokens_predict, num_steps, temperature):
    """Beam search over thought actions using only the policy head.

    Every outer step grows each hypothesis by up to ``n_tokens_predict``
    actions and then prunes to the ``beam_size`` best, ranked by
    ``score - 0.1 * cum_entropy + 0.05 * cum_variance``. Returns the action
    list of the best hypothesis.
    """
    beam = [(initial_state, 0.0, 0.0, 0.0, [])]
    for _ in range(num_steps):
        frontier, finished = beam, []
        extended = False
        for _ in range(n_tokens_predict):
            grown = []
            for hypothesis in frontier:
                extensions = _extend_hypothesis(hypothesis, prediction_network, beam_size, temperature)
                if extensions:
                    grown.extend(extensions)
                else:
                    # Dead end: keep the hypothesis as a finished candidate.
                    finished.append(hypothesis)
            if not grown:
                frontier = []
                break
            frontier, extended = grown, True
        if not extended:
            break  # nothing could be extended; the beam is final
        beam = sorted(
            finished + frontier,
            key=lambda h: h[1] - 0.1 * h[2] + 0.05 * h[3],
            reverse=True,
        )[:beam_size]
    return beam[0][4] if beam else []


def infer(query, world_model_components, root_thought_node, tokenizer, max_length=2000, inference_mode='world_model', beam_size=5, n_tokens_predict=3, mcts_iterations=10, exploration_constant=1.414, temperature=None):
    """
    Perform inference given a query, utilizing the Tree of Thought and MCTS with multi-token beam search.

    Args:
        query (str): The input query or prompt.
        world_model_components (tuple): ``(representation_network,
            dynamics_network, prediction_network, action_encoder, ppo_agent,
            model_transformer)``.
        root_thought_node (ThoughtNode): The root node of the Tree of Thought.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer used.
        max_length (int): Maximum length for the generated sequence.
        inference_mode (str): ``'without_world_model'`` generates text with
            the transformer alone; ``'world_model_tree_of_thought'`` runs
            MCTS over the thought tree; any other value (default
            ``'world_model'``) runs a policy-only beam search over the
            thought tree.
        beam_size (int): Size of the beam for beam search.
        n_tokens_predict (int): Number of tokens/actions to predict per step.
        mcts_iterations (int): Iterations per MCTS call.
        exploration_constant (float): UCB exploration constant for MCTS.
        temperature (float | None): Softmax temperature. When None, the value
            of a module-level ``args.temperature`` is used if it exists,
            otherwise 1.0.

    Returns:
        List[str] or str: The sequence of actions (thoughts) selected, or the
        generated text in ``'without_world_model'`` mode.
    """
    representation_network, dynamics_network, prediction_network, action_encoder, _ppo_agent, model_transformer = world_model_components

    if temperature is None:
        # Fall back to the script-level CLI namespace when one is defined.
        temperature = getattr(globals().get('args'), 'temperature', 1.0)

    # Tokenize and encode the query
    input_ids = tokenizer.encode(query, return_tensors='pt').to(device)

    # NOTE(review): the loop bound was unreadable in the source; this assumes
    # each search step accounts for roughly n_tokens_predict actions - confirm.
    num_steps = max(1, max_length // max(1, n_tokens_predict))

    # Pure inference: no autograd graph is needed anywhere below.
    with torch.no_grad():
        if inference_mode == 'without_world_model':
            # Directly use the transformer model to generate text with beam search
            generated_sequences = model_transformer.generate_with_beam_search(
                src=input_ids,
                tokenizer=tokenizer,
                beam_size=beam_size,
                max_length=max_length,
                n_tokens_predict=n_tokens_predict,
                temperature=temperature
            )
            best_sequence, _best_score = generated_sequences[0]
            return tokenizer.decode(best_sequence[0], skip_special_tokens=True)

        # World-model modes: start from the representation of the last input position.
        transformer_output = model_transformer(input_ids, input_ids)
        initial_representation = representation_network(transformer_output)
        initial_representation = initial_representation[:, -1, :].unsqueeze(1)
        initial_state = State(
            representation=initial_representation,
            dynamics_network=dynamics_network,
            action_encoder=action_encoder,
            thought_node=root_thought_node
        )

        if inference_mode == 'world_model_tree_of_thought':
            # Use MCTS with Tree of Thought and multi-token beam search
            mcts = MCTS(
                prediction_network,
                dynamics_network,
                action_encoder,
                num_iterations=mcts_iterations,
                exploration_constant=exploration_constant,
                beam_size=beam_size,
                n_tokens_predict=n_tokens_predict,
            )

            current_state = initial_state
            thought_sequence = []

            for _ in range(num_steps):
                best_actions = mcts.search_with_beam(current_state)
                if not best_actions:
                    break  # search made no progress; avoid spinning

                thought_sequence.extend(best_actions)

                # Apply the best actions to get the next state
                for action in best_actions:
                    current_state = current_state.apply_action(action)

                # Stop at a leaf thought node (no further actions)
                if not current_state.thought_node.children:
                    break

            return thought_sequence

        # Use the world model without MCTS, but with multi-token beam search
        return _beam_search_thoughts(
            initial_state,
            prediction_network,
            beam_size,
            n_tokens_predict,
            num_steps,
            temperature,
        )
|
|
|
|
def train_epoch_world_model(world_model_components, train_loader, optimizer, scheduler, scaler, args, model_transformer, state_dim, embed_dim, input_dim):
    """Run one world-model training epoch on top of a frozen Transformer.

    Args:
        world_model_components: 6-tuple unpacked as (representation_network,
            dynamics_network, prediction_network, action_encoder, ppo_agent, _);
            the last element is ignored here.
        train_loader: Sized iterable yielding dict batches that contain
            ``'input_ids'`` and ``'labels'`` tensors.
        optimizer: Optimizer whose parameter groups are clipped and stepped.
        scheduler: LR scheduler; stepped once per optimizer step.
        scaler: AMP gradient scaler used for backward, unscale and step.
        args: Namespace providing ``accumulation_steps`` and ``max_grad_norm``.
        model_transformer: Called as ``model_transformer(src, tgt)`` under
            ``torch.no_grad()``; its weights are not updated here.
        state_dim: Size used to flatten the state representation for InfoNCE.
        embed_dim: Size used to flatten the action embeddings.
        input_dim: Number of classes of the one-hot policy target.

    Returns:
        float: Mean (un-scaled) total loss per batch; 0.0 for an empty loader.
    """
    representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, _ = world_model_components
    representation_network.train()
    dynamics_network.train()
    prediction_network.train()
    action_encoder.train()
    ppo_agent.policy_network.train()
    # The Transformer is only a frozen feature extractor here; no_grad() alone
    # does not disable dropout, so switch it to eval mode explicitly.
    model_transformer.eval()

    num_batches = len(train_loader)
    # Mixed precision is only meaningful on CUDA; on CPU autocast stays off
    # (the original hard-coded 'cuda' and relied on PyTorch disabling it).
    use_amp = device.type == 'cuda'

    total_loss = 0.0
    optimizer.zero_grad()
    print(f"Starting World Model training epoch with {num_batches} batches...")

    for i, batch in enumerate(train_loader):
        print(f"Processing batch {i+1}/{num_batches}...")

        # Move batches to the device
        src_batch = batch['input_ids'].to(device)
        tgt_batch = batch['labels'].to(device)

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            print("Forward pass through Transformer (frozen)...")
            with torch.no_grad():
                transformer_output = model_transformer(src_batch, tgt_batch[:, :-1])

            # World Model - Representation
            state_representation = representation_network(transformer_output)

            # For simplicity, let's assume true actions are provided (e.g., next tokens)
            true_actions = tgt_batch[:, :-1]
            action_sequences = true_actions

            # Get action embeddings
            action_embeddings = action_encoder(action_sequences)

            # Apply dynamics network
            predicted_next_state_batch = dynamics_network(state_representation, action_embeddings)

            # Prediction Network - Policy logits and value
            policy_logits, value_estimates = prediction_network(predicted_next_state_batch)

            # Placeholder targets: one-hot of the taken action and an all-zero value.
            true_policy = F.one_hot(true_actions, num_classes=input_dim).float()
            true_value = torch.zeros_like(value_estimates)

            # Compute individual losses.
            # NOTE(review): the three zero tensors are placeholders; their meaning
            # depends on PPOAgent.compute_loss, which is defined elsewhere.
            ppo_loss = ppo_agent.compute_loss(
                state_representation,
                torch.zeros_like(true_actions, dtype=torch.float32),
                true_actions,
                torch.zeros_like(value_estimates, dtype=torch.float32),
                torch.zeros_like(value_estimates, dtype=torch.float32)
            )

            # Contrast each state with a dropout-perturbed copy of itself.
            flat_states = state_representation.reshape(-1, state_dim)
            info_nce = InfoNCE_Loss()(
                flat_states,
                F.dropout(flat_states, p=0.1, training=True)
            )

            covariance = CovarianceRegularization()(predicted_next_state_batch.view(-1, predicted_next_state_batch.size(-1)))
            dynamics_loss = DynamicsPerformanceLoss()(state_representation, predicted_next_state_batch)

            # Small Gaussian perturbation of the predicted next state.
            perturbed_next_state = predicted_next_state_batch + torch.randn_like(predicted_next_state_batch) * 0.01
            thought_loss = ThoughtConsistencyLoss()(predicted_next_state_batch, perturbed_next_state)

            pv_loss = PolicyValueJointLoss()(policy_logits, true_policy, value_estimates.squeeze(-1), true_value.squeeze(-1))
            action_diversity = ActionDiversityReward()(action_embeddings.view(-1, embed_dim))

            # Placeholder MCTS statistics (all zeros / all ones).
            mcts_best_values = torch.zeros(true_actions.size(0), device=device)
            etv = ExpectedThoughtValueLoss()(mcts_best_values)

            visit_counts = torch.ones(true_actions.size(0), policy_logits.size(-1), device=device)
            exploration = ExplorationRegularization()(visit_counts)

            old_policy = F.softmax(policy_logits.detach(), dim=-1)
            new_policy = F.softmax(policy_logits, dim=-1)
            kl_loss = KL_DivergenceLoss()(old_policy, new_policy)

            # Total Loss
            loss = (
                ppo_loss +
                info_nce +
                covariance +
                dynamics_loss +
                thought_loss +
                pv_loss +
                action_diversity +
                etv +
                exploration +
                kl_loss
            )
            # Scale so gradients accumulated over several batches average out.
            loss = loss / args.accumulation_steps

        print("Backward pass...")
        scaler.scale(loss).backward()

        # Step after every `accumulation_steps` batches and on the last batch.
        if (i + 1) % args.accumulation_steps == 0 or (i + 1) == num_batches:
            print("Gradient clipping...")
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(
                [param for group in optimizer.param_groups for param in group['params']],
                args.max_grad_norm
            )

            print("Optimizer step...")
            scaler.step(optimizer)
            scaler.update()

            print("Zeroing gradients...")
            optimizer.zero_grad()

            print("Updating learning rate...")
            scheduler.step()

        # Undo the accumulation scaling so the running total is per-batch loss.
        total_loss += loss.item() * args.accumulation_steps

        # Print individual losses and total loss for this batch
        print(f"Batch {i+1} completed. Losses:")
        print(f" PPO Loss: {ppo_loss.item():.4f}")
        print(f" InfoNCE Loss: {info_nce.item():.4f}")
        print(f" Covariance Loss: {covariance.item():.4f}")
        print(f" Dynamics Loss: {dynamics_loss.item():.4f}")
        print(f" Thought Consistency Loss: {thought_loss.item():.4f}")
        print(f" Policy-Value Loss: {pv_loss.item():.4f}")
        print(f" Action Diversity Loss: {action_diversity.item():.4f}")
        print(f" Expected Thought Value Loss: {etv.item():.4f}")
        print(f" Exploration Loss: {exploration.item():.4f}")
        print(f" KL Divergence Loss: {kl_loss.item():.4f}")
        print(f" Total Loss: {loss.item():.4f}")

    # Guard against an empty loader instead of dividing by zero.
    avg_loss = total_loss / num_batches if num_batches else 0.0
    print(f"World Model training epoch completed. Average loss: {avg_loss:.4f}")
    return avg_loss
|
|
def train_epoch_language_model(model, train_loader, optimizer, scheduler, scaler, args):
    """Run one language-model training epoch with gradient accumulation.

    Args:
        model: Module called as ``model(input_ids, input_ids)``; must expose
            ``model.embedding.padding_idx`` (used as the ignored label index).
        train_loader: Sized iterable yielding dict batches that contain
            ``'input_ids'`` and ``'labels'`` tensors.
        optimizer: Optimizer whose parameter groups are clipped and stepped.
        scheduler: LR scheduler; stepped once per optimizer step.
        scaler: AMP gradient scaler used for backward, unscale and step.
        args: Namespace providing ``accumulation_steps`` and ``max_grad_norm``.

    Returns:
        float: Mean (un-scaled) cross-entropy loss per batch; 0.0 for an
        empty loader.
    """
    model.train()
    num_batches = len(train_loader)
    # Same autocast entry point as the world-model loops; only active on CUDA.
    use_amp = device.type == 'cuda'

    total_loss = 0.0
    optimizer.zero_grad()
    print(f"Starting Language Model training epoch with {num_batches} batches...")

    for i, batch in enumerate(train_loader):
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            outputs = model(input_ids, input_ids)
            # Flatten every leading dimension so each position is one
            # classification example; reshape also copes with non-contiguous
            # tensors.
            logits = outputs.reshape(-1, outputs.size(-1))
            labels = labels.reshape(-1)
            loss = F.cross_entropy(logits, labels, ignore_index=model.embedding.padding_idx)
            # Scale so gradients accumulated over several batches average out.
            loss = loss / args.accumulation_steps

        scaler.scale(loss).backward()

        # Step after every `accumulation_steps` batches and on the last batch.
        if (i + 1) % args.accumulation_steps == 0 or (i + 1) == num_batches:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(
                [param for group in optimizer.param_groups for param in group['params']],
                args.max_grad_norm
            )
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()

        # Undo the accumulation scaling so the running total is per-batch loss.
        total_loss += loss.item() * args.accumulation_steps
        print(f"Batch {i + 1} completed. Current loss: {loss.item():.4f}")

    # Guard against an empty loader instead of dividing by zero.
    avg_loss = total_loss / num_batches if num_batches else 0.0
    print(f"Language Model training epoch completed. Average loss: {avg_loss:.4f}")
    return avg_loss
|
|
|
|
def train_custom_data_epoch_world_model(world_model_components, train_loader, optimizer, scheduler, scaler, args, model_transformer, state_dim, embed_dim, input_dim):
    """Run one world-model training epoch on custom (reward-annotated) data.

    Args:
        world_model_components: 6-tuple unpacked as (representation_network,
            dynamics_network, prediction_network, action_encoder, ppo_agent, _);
            the last element is ignored here.
        train_loader: Sized iterable yielding dict batches that contain
            ``'input_ids'``, ``'episode_reward'``, ``'cosine_similarity'``,
            ``'rag_performance'`` and ``'ranking_model_performance'`` tensors.
        optimizer: Optimizer whose parameter groups are clipped and stepped.
        scheduler: LR scheduler; stepped once per optimizer step.
        scaler: AMP gradient scaler used for backward, unscale and step.
        args: Namespace providing ``accumulation_steps`` and ``max_grad_norm``.
        model_transformer: Called as ``model_transformer(input_ids, input_ids)``
            under ``torch.no_grad()``; its weights are not updated here.
        state_dim: Size used to flatten the state representation for InfoNCE.
        embed_dim: Size used to flatten the action embeddings.
        input_dim: Number of classes of the one-hot policy target.

    Returns:
        float: Mean (un-scaled) total loss per batch; 0.0 for an empty loader.
    """
    representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, _ = world_model_components
    representation_network.train()
    dynamics_network.train()
    prediction_network.train()
    action_encoder.train()
    ppo_agent.policy_network.train()
    # The Transformer is only a frozen feature extractor here; no_grad() alone
    # does not disable dropout, so switch it to eval mode explicitly.
    model_transformer.eval()

    num_batches = len(train_loader)
    # Mixed precision is only meaningful on CUDA; on CPU autocast stays off.
    use_amp = device.type == 'cuda'

    total_loss = 0.0
    optimizer.zero_grad()
    print(f"Starting World Model training epoch with {num_batches} batches...")

    for i, batch in enumerate(train_loader):
        print(f"Processing batch {i+1}/{num_batches}...")

        # Move the tensors that are actually used to the device
        input_ids = batch['input_ids'].to(device)
        episode_reward = batch['episode_reward'].to(device)
        cosine_similarity = batch['cosine_similarity'].to(device)
        rag_performance = batch['rag_performance'].to(device)
        ranking_model_performance = batch['ranking_model_performance'].to(device)

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            print("Forward pass through Transformer (frozen)...")
            with torch.no_grad():
                transformer_output = model_transformer(input_ids, input_ids)

            # World Model - Representation
            state_representation = representation_network(transformer_output)
            print(f"State representation shape: {state_representation.shape}")

            # For simplicity, let's assume true actions are provided (e.g., next tokens)
            true_actions = input_ids[:, 1:]  # Shift input_ids by 1 to get next tokens
            print(f"True actions shape: {true_actions.shape}")
            action_sequences = true_actions

            # Get action embeddings
            action_embeddings = action_encoder(action_sequences)
            print(f"Action embeddings shape: {action_embeddings.shape}")

            # Ensure state_representation and action_embeddings have the same sequence length
            min_seq_len = min(state_representation.size(1), action_embeddings.size(1))
            state_representation = state_representation[:, :min_seq_len, :]
            action_embeddings = action_embeddings[:, :min_seq_len, :]

            print(f"Adjusted state representation shape: {state_representation.shape}")
            print(f"Adjusted action embeddings shape: {action_embeddings.shape}")

            # Apply dynamics network
            predicted_next_state_batch = dynamics_network(state_representation, action_embeddings)
            print(f"Predicted next state batch shape: {predicted_next_state_batch.shape}")

            # Prediction Network - Policy logits and value
            policy_logits, value_estimates = prediction_network(predicted_next_state_batch)

            # Adjust true_actions to match the sequence length
            true_actions = true_actions[:, :min_seq_len]

            # Define true_policy and true_value
            true_policy = F.one_hot(true_actions, num_classes=input_dim).float()
            true_value = episode_reward.unsqueeze(1).expand(-1, min_seq_len)  # Expand to match sequence length

            # Compute individual losses.
            # Contrast each state with a dropout-perturbed copy of itself.
            flat_states = state_representation.reshape(-1, state_dim)
            info_nce = InfoNCE_Loss()(
                flat_states,
                F.dropout(flat_states, p=0.1, training=True)
            )

            covariance = CovarianceRegularization()(predicted_next_state_batch.view(-1, predicted_next_state_batch.size(-1)))
            dynamics_loss = DynamicsPerformanceLoss()(state_representation, predicted_next_state_batch)

            # Small Gaussian perturbation of the predicted next state.
            perturbed_next_state = predicted_next_state_batch + torch.randn_like(predicted_next_state_batch) * 0.01
            thought_loss = ThoughtConsistencyLoss()(predicted_next_state_batch, perturbed_next_state)

            pv_loss = PolicyValueJointLoss()(policy_logits, true_policy, value_estimates.squeeze(-1), true_value.squeeze(-1))
            action_diversity = ActionDiversityReward()(action_embeddings.view(-1, embed_dim))

            # Placeholder MCTS statistics (all zeros / all ones).
            mcts_best_values = torch.zeros(true_actions.size(0), device=device)
            etv = ExpectedThoughtValueLoss()(mcts_best_values)

            visit_counts = torch.ones(true_actions.size(0), policy_logits.size(-1), device=device)
            exploration = ExplorationRegularization()(visit_counts)

            old_policy = F.softmax(policy_logits.detach(), dim=-1)
            new_policy = F.softmax(policy_logits, dim=-1)
            kl_loss = KL_DivergenceLoss()(old_policy, new_policy)

            # Compute mean value estimates over the sequence length
            value_estimates_mean = value_estimates.squeeze(-1).mean(dim=1)  # Shape: [batch_size]

            # Add new loss components
            rag_loss = F.mse_loss(value_estimates_mean, rag_performance)
            ranking_loss = F.mse_loss(value_estimates_mean, ranking_model_performance)
            cosine_similarity_loss = 1 - cosine_similarity.mean()  # Maximize cosine similarity

            # Total Loss
            loss = (
                info_nce +
                covariance +
                dynamics_loss +
                thought_loss +
                pv_loss +
                action_diversity +
                etv +
                exploration +
                kl_loss +
                rag_loss +
                ranking_loss +
                cosine_similarity_loss
            )
            # Scale so gradients accumulated over several batches average out.
            loss = loss / args.accumulation_steps

        print("Backward pass...")
        scaler.scale(loss).backward()

        # Step after every `accumulation_steps` batches and on the last batch.
        if (i + 1) % args.accumulation_steps == 0 or (i + 1) == num_batches:
            print("Gradient clipping...")
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(
                [param for group in optimizer.param_groups for param in group['params']],
                args.max_grad_norm
            )

            print("Optimizer step...")
            scaler.step(optimizer)
            scaler.update()

            print("Zeroing gradients...")
            optimizer.zero_grad()

            print("Updating learning rate...")
            scheduler.step()

        # Accumulate the un-scaled batch loss; without this the reported
        # epoch average is always 0.0.
        total_loss += loss.item() * args.accumulation_steps

        # Print individual losses and total loss for this batch
        print(f"Batch {i+1} completed. Losses:")
        print(f" InfoNCE Loss: {info_nce.item():.4f}")
        print(f" Covariance Loss: {covariance.item():.4f}")
        print(f" Dynamics Loss: {dynamics_loss.item():.4f}")
        print(f" Thought Consistency Loss: {thought_loss.item():.4f}")
        print(f" Policy-Value Loss: {pv_loss.item():.4f}")
        print(f" Action Diversity Loss: {action_diversity.item():.4f}")
        print(f" Expected Thought Value Loss: {etv.item():.4f}")
        print(f" Exploration Loss: {exploration.item():.4f}")
        print(f" KL Divergence Loss: {kl_loss.item():.4f}")
        print(f" RAG Loss: {rag_loss.item():.4f}")
        print(f" Ranking Loss: {ranking_loss.item():.4f}")
        print(f" Cosine Similarity Loss: {cosine_similarity_loss.item():.4f}")
        print(f" Total Loss: {loss.item():.4f}")

    # Guard against an empty loader instead of dividing by zero.
    avg_loss = total_loss / num_batches if num_batches else 0.0
    print(f"World Model training epoch completed. Average loss: {avg_loss:.4f}")
    return avg_loss
|
|
|
|
def main():
    """Entry point: build the models, then train or run interactive inference.

    Reads its configuration from ``parse_args()``. In ``train`` mode it trains
    either the world model or the language model (``args.train_mode``) for
    ``args.num_epochs`` epochs and saves a checkpoint after each epoch. In
    ``inference`` mode it builds the Tree of Thought, optionally loads saved
    weights from ``args.load_model`` and answers queries until the user types
    ``exit`` (or stdin is closed). Any other mode only prints an error message.
    """
    # Action vocabulary mappings are published as module-level globals in
    # inference mode.
    global action_to_index, index_to_action

    args = parse_args()
    print("Arguments parsed successfully.")

    # Create save directory
    os.makedirs(args.save_dir, exist_ok=True)
    print(f"Save directory created: {args.save_dir}")

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if tokenizer.pad_token is None:
        # Fall back to the EOS token when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token
    print("Tokenizer loaded successfully.")

    # Input dimension is the tokenizer vocabulary size
    input_dim = len(tokenizer)

    # Initialize the Transformer model on GPU
    print("Initializing Transformer model...")
    model_transformer = Transformer(
        input_dim=input_dim,
        d_model=128,
        num_heads=4,
        num_layers=4,
        d_ff=256,
        num_experts=2,
        output_dim=input_dim,
        dropout=0.1,
        top_k=2
    ).to(device)
    model_transformer.train()
    print("Transformer model initialized on device.")

    # Define model parameters (adjusted for speed)
    d_model = 32
    state_dim = 32
    action_dim = d_model
    hidden_dim = 64
    vocab_dim = input_dim
    embed_dim = d_model

    # Define World Model components
    representation_network = RepresentationNetwork(vocab_dim, d_model, state_dim).to(device)
    dynamics_network = DynamicsNetwork(state_dim, action_dim, hidden_dim).to(device)
    prediction_network = PredictionNetwork(state_dim, input_dim, 1).to(device)
    action_encoder = ActionEncoder(input_dim, action_dim).to(device)

    # Initialize PPO Agent
    ppo_agent = PPOAgent(
        policy_network=prediction_network,
        optimizer=optim.AdamW(prediction_network.parameters(), lr=args.learning_rate),
        clip_epsilon=0.2,
        entropy_coef=0.01,
        value_coef=0.5
    )

    # Bundle World Model components
    world_model_components = (representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer)

    print(f"Current mode: {args.mode}")
    if args.mode == 'train':
        print("Loading and preprocessing data...")
        if args.use_custom_data:
            custom_data = load_custom_data_from_files(args.custom_data_paths)
            processed_data = preprocess_custom_data(custom_data)
            train_loader, eval_loader = load_custom_data(args, tokenizer, processed_data)
            print("Custom data loaded and preprocessed successfully.")
        else:
            train_loader, eval_loader = load_data(args, tokenizer)
            print("Default data loaded and preprocessed successfully.")

        # Optimizer and Scheduler
        if args.train_mode == 'world_model':
            optimizer = optim.AdamW(
                list(representation_network.parameters()) +
                list(dynamics_network.parameters()) +
                list(prediction_network.parameters()) +
                list(action_encoder.parameters()),
                lr=args.learning_rate, weight_decay=args.weight_decay
            )
        else:
            # Honour --weight_decay here as well (it used to be silently ignored).
            optimizer = optim.AdamW(
                model_transformer.parameters(),
                lr=args.learning_rate, weight_decay=args.weight_decay
            )

        # The epoch functions call scheduler.step() once per optimizer step,
        # so the cosine period must span all optimizer steps of the run
        # rather than the number of epochs.
        steps_per_epoch = math.ceil(len(train_loader) / args.accumulation_steps)
        total_optimizer_steps = max(1, args.num_epochs * steps_per_epoch)
        scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps)
        # Gradient scaling is only useful (and only supported) on CUDA.
        scaler = GradScaler(enabled=device.type == 'cuda')

        print(f"Starting {args.train_mode} training...")

        for epoch in range(args.num_epochs):
            if args.train_mode == 'world_model':
                if args.use_custom_data:
                    avg_loss = train_custom_data_epoch_world_model(
                        world_model_components,
                        train_loader,
                        optimizer,
                        scheduler,
                        scaler,
                        args,
                        model_transformer,
                        state_dim,
                        embed_dim,
                        input_dim
                    )
                else:
                    avg_loss = train_epoch_world_model(
                        world_model_components,
                        train_loader,
                        optimizer,
                        scheduler,
                        scaler,
                        args,
                        model_transformer,
                        state_dim,
                        embed_dim,
                        input_dim
                    )
            else:
                avg_loss = train_epoch_language_model(
                    model_transformer,
                    train_loader,
                    optimizer,
                    scheduler,
                    scaler,
                    args
                )

            print(f"{args.train_mode.capitalize()} training epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")

            # Save models
            if args.train_mode == 'world_model':
                save_all_models(model_transformer, representation_network, dynamics_network, prediction_network, action_encoder, args.save_dir, epoch + 1)
                print(f"Models saved for epoch {epoch + 1}")
            else:
                torch.save(model_transformer.state_dict(), os.path.join(args.save_dir, f'language_model_epoch_{epoch + 1}.pt'))
                print(f"Language model saved for epoch {epoch + 1}")

        print("Training completed.")

    elif args.mode == 'inference':
        print("Entering inference mode...")
        # Build Tree of Thought if needed
        print("Building Tree of Thought...")
        tree_root = build_tree_of_thought()
        print("Tree of Thought built successfully.")

        # Generate action list
        print("Generating action list...")
        action_list = []
        traverse_tree(tree_root, action_list)
        print(f"Action list generated. Total actions: {len(action_list)}")

        # Create mappings
        action_to_index = {action: idx for idx, action in enumerate(action_list)}
        index_to_action = {idx: action for action, idx in action_to_index.items()}
        action_vocab_size = len(action_list)
        print(f"Action mappings created. Vocabulary size: {action_vocab_size}")

        # Initialize or load models based on the load_model argument
        if args.load_model:
            print(f"Loading saved model from {args.load_model}")

            def load_state(filename):
                # map_location lets checkpoints saved on a GPU load on CPU-only hosts.
                # NOTE: torch.load unpickles the file; only load trusted checkpoints.
                return torch.load(os.path.join(args.load_model, filename), map_location=device)

            # Load the saved models
            model_transformer.load_state_dict(load_state('transformer_model.pt'))
            representation_network.load_state_dict(load_state('representation_network.pt'))
            dynamics_network.load_state_dict(load_state('dynamics_network.pt'))

            # Load prediction network and adjust its size if necessary
            saved_state_dict = load_state('prediction_network.pt')
            saved_vocab_size = saved_state_dict['policy_head.weight'].size(0)
            if saved_vocab_size != action_vocab_size:
                print(f"Adjusting prediction network size from {saved_vocab_size} to {action_vocab_size}")
                # Load with the saved shape first, then swap in a fresh policy
                # head sized for the current action vocabulary.
                prediction_network = PredictionNetwork(state_dim, saved_vocab_size, 1).to(device)
                prediction_network.load_state_dict(saved_state_dict)
                prediction_network.policy_head = nn.Linear(prediction_network.state_dim, action_vocab_size).to(device)
            else:
                prediction_network = PredictionNetwork(state_dim, action_vocab_size, 1).to(device)
                prediction_network.load_state_dict(saved_state_dict)
            # NOTE(review): ppo_agent still wraps the prediction network created
            # before loading; confirm whether infer() relies on ppo_agent.policy_network.

            action_encoder.load_state_dict(load_state('action_encoder.pt'))
        else:
            print("Using newly initialized models")

        # Prepare the components
        world_model_components = (representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer)

        print("Starting inference loop...")
        while True:
            if args.query:
                query = args.query
                args.query = None  # Reset query for next iteration
            else:
                try:
                    query = input("Please enter your query (or type 'exit' to quit): ")
                except EOFError:
                    # stdin closed (non-interactive run): leave the loop cleanly.
                    break
                if query.lower() == 'exit':
                    break

            print(f"Processing query: {query}")
            result = infer(query, world_model_components, tree_root, tokenizer,
                           max_length=args.max_length,
                           inference_mode=args.inference_mode,
                           beam_size=args.beam_size,
                           n_tokens_predict=args.n_tokens_predict,
                           mcts_iterations=args.mcts_iterations,
                           exploration_constant=args.mcts_exploration_constant)

            if args.inference_mode == 'without_world_model':
                print("Generated Text:")
                print(result)
            else:
                print("Generated Thought Sequence:")
                for thought in result:
                    print(thought)

            print("\n")  # Add a newline for better readability between queries

        print("Inference completed.")

    else:
        print(f"Invalid mode: {args.mode}. Please choose 'train' or 'inference'.")
if __name__ == '__main__':
    # Default configuration, used when the script is started without any
    # command-line arguments or from a notebook kernel (whose own argv would
    # otherwise confuse argparse).
    # NOTE(review): '--custom_data_paths' is repeated once per path; whether
    # argparse keeps all three depends on how parse_args() declares it - confirm.
    default_argv = [
        'lightbulb_2.py',
        '--mode', 'inference',
        '--train_mode', 'world_model',  # Set 'world_model' or 'language_model' depending on the training mode
        '--dataset_name', 'wikitext',  # Specify the Hugging Face dataset (e.g., 'wikitext')
        '--dataset_config', 'wikitext-2-raw-v1',  # Use if you need a specific config of the dataset
        '--num_epochs', '10',
        '--batch_size', '4',
        '--accumulation_steps', '1',
        '--max_grad_norm', '1.0',
        '--weight_decay', '0.01',
        '--learning_rate', '1e-4',
        '--max_length', '512',
        '--save_dir', './trained_models',
        # Uncomment the following line to use custom data instead of a Hugging Face dataset
        #'--use_custom_data',
        '--custom_data_paths', '/content/drive/MyDrive/lightbulb/knowledge_base.json',
        '--custom_data_paths', '/content/drive/MyDrive/lightbulb/rag_cache.json',
        '--custom_data_paths', '/content/drive/MyDrive/lightbulb/llm_training_data/llm_training_data.jsonl'
    ]

    # Respect real command-line arguments instead of always overwriting them.
    if len(sys.argv) <= 1 or 'ipykernel' in sys.modules:
        sys.argv = default_argv

    # Parse the arguments once here for the banner; main() parses them again.
    args = parse_args()

    # Report the data source only when a training run was requested.
    if args.mode == 'train':
        if args.use_custom_data:
            print("Training with custom data from paths:")
            for path in args.custom_data_paths:
                print(f" - {path}")
        else:
            print(f"Training with dataset '{args.dataset_name}' from Hugging Face Datasets")

    main()
|
|
|
|