---
license: mit
---
# Towards a Physics Foundation Model

Weights for the Physics Foundation Models presented in [Towards a Physics Foundation Model](https://arxiv.org/abs/2509.13805).
The weights correspond to the four model sizes (S, M, L, XL).
The exact configs for these models are found in the [GitHub repository](https://github.com/FloWsnr/General-Physics-Transformer), specifically in the ``model_specs`` file.

Use the following function (or a similar one) to initialize the model, then load the corresponding state dict:
```python
# model_specs and PhysicsTransformer are provided by the GitHub repository linked above.
def get_model(model_config: dict):
    """Initialize a PhysicsTransformer from a model config dict."""
    transformer_config: dict = model_config["transformer"]
    tokenizer_config: dict = model_config["tokenizer"]

    if transformer_config["model_size"] == "GPT_S":
        gpt_config = model_specs.GPT_S()
    elif transformer_config["model_size"] == "GPT_M":
        gpt_config = model_specs.GPT_M()
    elif transformer_config["model_size"] == "GPT_L":
        gpt_config = model_specs.GPT_L()
    elif transformer_config["model_size"] == "GPT_XL":
        gpt_config = model_specs.GPT_XL()
    else:
        raise ValueError(f"Invalid model size: {transformer_config['model_size']}")

    return PhysicsTransformer(
        num_fields=transformer_config["input_channels"],
        hidden_dim=gpt_config.hidden_dim,
        mlp_dim=gpt_config.mlp_dim,
        num_heads=gpt_config.num_heads,
        num_layers=gpt_config.num_layers,
        att_mode=transformer_config.get("att_mode", "full"),
        integrator=transformer_config.get("integrator", "Euler"),
        pos_enc_mode=transformer_config["pos_enc_mode"],
        img_size=model_config["img_size"],
        patch_size=transformer_config["patch_size"],
        use_derivatives=transformer_config["use_derivatives"],
        tokenizer_mode=tokenizer_config["tokenizer_mode"],
        detokenizer_mode=tokenizer_config["detokenizer_mode"],
        tokenizer_overlap=tokenizer_config["tokenizer_overlap"],
        detokenizer_overlap=tokenizer_config["detokenizer_overlap"],
        tokenizer_net_channels=gpt_config.conv_channels,
        detokenizer_net_channels=gpt_config.conv_channels,
        dropout=transformer_config["dropout"],
        stochastic_depth_rate=transformer_config["stochastic_depth_rate"],
    )
```
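The keys read by `get_model` imply a config layout along these lines. This is only a sketch: all values below are illustrative placeholders, and the authoritative configs are the ones in the repository's ``model_specs`` file.

```python
# Illustrative config matching the keys consumed by get_model above.
# Every value is a placeholder; use the real configs from the GitHub repo.
model_config = {
    "img_size": 256,                 # placeholder spatial resolution
    "transformer": {
        "model_size": "GPT_M",       # one of GPT_S / GPT_M / GPT_L / GPT_XL
        "input_channels": 5,         # number of physical field channels (placeholder)
        "att_mode": "full",
        "integrator": "Euler",
        "pos_enc_mode": "absolute",  # placeholder
        "patch_size": 16,            # placeholder
        "use_derivatives": True,
        "dropout": 0.0,
        "stochastic_depth_rate": 0.0,
    },
    "tokenizer": {
        "tokenizer_mode": "linear",  # placeholder
        "detokenizer_mode": "linear",
        "tokenizer_overlap": 0,
        "detokenizer_overlap": 0,
    },
}

# After initializing, load a downloaded checkpoint with PyTorch
# (the checkpoint filename here is illustrative):
# model = get_model(model_config)
# state_dict = torch.load("GPT_M.pt", map_location="cpu")
# model.load_state_dict(state_dict)
# model.eval()
```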