---
license: mit
---
# Towards a Physics Foundation Model

Weights for the Physics Foundation Models presented in [Towards a Physics Foundation Model](https://arxiv.org/abs/2509.13805).
The weights correspond to the different model sizes (S, M, L, XL).
The exact configs for these models are found in the [GitHub repo](https://github.com/FloWsnr/General-Physics-Transformer), specifically in the ``model_specs`` file.

Use the following function (or similar) to initialize the model and then load the state dicts:
```python
# `model_specs` and `PhysicsTransformer` are defined in the GitHub repo linked above.
def get_model(model_config: dict):
    """Get the model."""
    transformer_config: dict = model_config["transformer"]
    tokenizer_config: dict = model_config["tokenizer"]

    if transformer_config["model_size"] == "GPT_S":
        gpt_config = model_specs.GPT_S()
    elif transformer_config["model_size"] == "GPT_M":
        gpt_config = model_specs.GPT_M()
    elif transformer_config["model_size"] == "GPT_L":
        gpt_config = model_specs.GPT_L()
    elif transformer_config["model_size"] == "GPT_XL":
        gpt_config = model_specs.GPT_XL()
    else:
        raise ValueError(f"Invalid model size: {transformer_config['model_size']}")

    return PhysicsTransformer(
        num_fields=transformer_config["input_channels"],
        hidden_dim=gpt_config.hidden_dim,
        mlp_dim=gpt_config.mlp_dim,
        num_heads=gpt_config.num_heads,
        num_layers=gpt_config.num_layers,
        att_mode=transformer_config.get("att_mode", "full"),
        integrator=transformer_config.get("integrator", "Euler"),
        pos_enc_mode=transformer_config["pos_enc_mode"],
        img_size=model_config["img_size"],
        patch_size=transformer_config["patch_size"],
        use_derivatives=transformer_config["use_derivatives"],
        tokenizer_mode=tokenizer_config["tokenizer_mode"],
        detokenizer_mode=tokenizer_config["detokenizer_mode"],
        tokenizer_overlap=tokenizer_config["tokenizer_overlap"],
        detokenizer_overlap=tokenizer_config["detokenizer_overlap"],
        tokenizer_net_channels=gpt_config.conv_channels,
        detokenizer_net_channels=gpt_config.conv_channels,
        dropout=transformer_config["dropout"],
        stochastic_depth_rate=transformer_config["stochastic_depth_rate"],
    )
```
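A minimal loading sketch, assuming a config dict of the shape `get_model` reads: every concrete value below (field count, patch size, image size, repo id, checkpoint filename) is a placeholder, not taken from this repository. The real values for each size are in the ``model_specs`` file of the GitHub repo.

```python
import torch
from huggingface_hub import hf_hub_download

# Hypothetical config dict with the keys `get_model` expects.
# All concrete values are placeholders; take the real ones from `model_specs`.
model_config = {
    "transformer": {
        "model_size": "GPT_M",       # one of GPT_S, GPT_M, GPT_L, GPT_XL
        "input_channels": 5,         # placeholder: number of physical fields
        "att_mode": "full",
        "integrator": "Euler",
        "pos_enc_mode": "absolute",  # placeholder
        "patch_size": 16,            # placeholder
        "use_derivatives": True,     # placeholder
        "dropout": 0.0,
        "stochastic_depth_rate": 0.0,
    },
    "tokenizer": {
        "tokenizer_mode": "linear",    # placeholder
        "detokenizer_mode": "linear",  # placeholder
        "tokenizer_overlap": 0,        # placeholder
        "detokenizer_overlap": 0,      # placeholder
    },
    "img_size": 256,  # placeholder spatial resolution
}

model = get_model(model_config)

# Repo id and filename are placeholders; substitute the actual
# checkpoint file from this repository.
ckpt_path = hf_hub_download(repo_id="<this-repo-id>", filename="GPT_M.pt")
state_dict = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```

Note that the config must match the checkpoint's size, since the hidden dimensions, head counts, and layer counts differ between S, M, L, and XL.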