flwi committed · Commit 0af3de7 · verified · Parent(s): c4fa5ec

Update README.md

Files changed (1): README.md (+50, −3)
README.md CHANGED
---
license: mit
---
# Towards a Physics Foundation Model

Weights for the Physics Foundation Models presented in [Towards a Physics Foundation Model](https://arxiv.org/abs/2509.13805).
The weights correspond to the different model sizes (S, M, L, XL).
The exact configs for these models can be found in the [GitHub repo](https://github.com/FloWsnr/General-Physics-Transformer), specifically in the ``model_specs`` file.

Use the following function (or similar) to initialize the model, then load the state dict:
```python
# model_specs and PhysicsTransformer are part of the
# General-Physics-Transformer repository linked above; import them
# from your local checkout of that repo.
def get_model(model_config: dict):
    """Initialize a PhysicsTransformer from a config dictionary."""
    transformer_config: dict = model_config["transformer"]
    tokenizer_config: dict = model_config["tokenizer"]

    # Map the requested model size to its architecture spec (S/M/L/XL).
    if transformer_config["model_size"] == "GPT_S":
        gpt_config = model_specs.GPT_S()
    elif transformer_config["model_size"] == "GPT_M":
        gpt_config = model_specs.GPT_M()
    elif transformer_config["model_size"] == "GPT_L":
        gpt_config = model_specs.GPT_L()
    elif transformer_config["model_size"] == "GPT_XL":
        gpt_config = model_specs.GPT_XL()
    else:
        raise ValueError(f"Invalid model size: {transformer_config['model_size']}")

    return PhysicsTransformer(
        num_fields=transformer_config["input_channels"],
        hidden_dim=gpt_config.hidden_dim,
        mlp_dim=gpt_config.mlp_dim,
        num_heads=gpt_config.num_heads,
        num_layers=gpt_config.num_layers,
        att_mode=transformer_config.get("att_mode", "full"),
        integrator=transformer_config.get("integrator", "Euler"),
        pos_enc_mode=transformer_config["pos_enc_mode"],
        img_size=model_config["img_size"],
        patch_size=transformer_config["patch_size"],
        use_derivatives=transformer_config["use_derivatives"],
        tokenizer_mode=tokenizer_config["tokenizer_mode"],
        detokenizer_mode=tokenizer_config["detokenizer_mode"],
        tokenizer_overlap=tokenizer_config["tokenizer_overlap"],
        detokenizer_overlap=tokenizer_config["detokenizer_overlap"],
        tokenizer_net_channels=gpt_config.conv_channels,
        detokenizer_net_channels=gpt_config.conv_channels,
        dropout=transformer_config["dropout"],
        stochastic_depth_rate=transformer_config["stochastic_depth_rate"],
    )
```
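The shape of the ``model_config`` dictionary that ``get_model`` expects can be read off from the keys it accesses. The sketch below is only a skeleton derived from those keys: every value is an illustrative placeholder, not a configuration used in the paper — take the real values from the configs in the GitHub repo.

```python
# Example model_config skeleton, derived from the keys get_model reads.
# All values are illustrative placeholders -- take the real ones from the
# configs in the General-Physics-Transformer repository.
model_config = {
    "img_size": (256, 128),          # placeholder spatial resolution
    "transformer": {
        "model_size": "GPT_M",       # one of GPT_S / GPT_M / GPT_L / GPT_XL
        "input_channels": 5,         # number of physical fields (placeholder)
        "att_mode": "full",          # default used by get_model
        "integrator": "Euler",       # default used by get_model
        "pos_enc_mode": "absolute",  # placeholder
        "patch_size": 16,            # placeholder
        "use_derivatives": True,     # placeholder
        "dropout": 0.0,              # placeholder
        "stochastic_depth_rate": 0.0,  # placeholder
    },
    "tokenizer": {
        "tokenizer_mode": "linear",      # placeholder
        "detokenizer_mode": "linear",    # placeholder
        "tokenizer_overlap": 0,          # placeholder
        "detokenizer_overlap": 0,        # placeholder
    },
}
```

With the real values filled in, ``model = get_model(model_config)`` followed by ``model.load_state_dict(torch.load(path, map_location="cpu"))`` should load one of the released checkpoints.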