Usage
Here, we highlight the main building blocks for performing finite-size scaling.
1. Scaling function
In jaxfss, we approximate the scaling function with a neural network.
We provide a simple multilayer perceptron (MLP) module, built upon Flax, to construct neural networks.
The following example builds an MLP with input and output dimension one and two hidden layers of width 20.
import jaxfss
import jax
import jax.numpy as jnp

# two hidden layers of width 20 and a one-dimensional output layer
mlp = jaxfss.MLP(features=[20, 20, 1])
# initialize the parameters with a dummy input of shape (1, 1)
mlp_params = mlp.init(jax.random.PRNGKey(0), jnp.array([[1]]))
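Since MLP is a Flax module, you can evaluate it with apply. A minimal sketch, reusing mlp and mlp_params from above with an illustrative batch of inputs:

import jax.numpy as jnp

# evaluate the network on a batch of inputs of shape (batch, 1)
x = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)
y = mlp.apply(mlp_params, x)  # output of shape (5, 1)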
The default activation function of MLP is the sigmoid function. You can change this via the act argument.
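For example, to use a ReLU activation instead (a sketch; we assume here that act accepts a JAX-compatible callable such as jax.nn.relu):

import jaxfss
import jax
import jax.numpy as jnp

# same MLP as above, but with ReLU instead of the default sigmoid
# (the callable form of act is an assumption, not confirmed here)
mlp = jaxfss.MLP(features=[20, 20, 1], act=jax.nn.relu)
mlp_params = mlp.init(jax.random.PRNGKey(0), jnp.array([[1]]))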
We also provide an MLP with a rational activation function.
import jaxfss
import jax.numpy as jnp
from jax import random

mlp = jaxfss.RationalMLP(features=[20, 20, 1])
mlp_params = mlp.init(random.PRNGKey(0), jnp.array([[1]]))
In RationalMLP, the activation function is approximated by a rational function, and the parameters of the rational function are also optimized during the learning process. See arXiv:2004.01902 for the original paper.
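For reference, a type-(3, 2) rational activation in the spirit of arXiv:2004.01902 can be sketched as follows. This is only an illustration of the idea, not the internal implementation of RationalMLP; taking the absolute value in the denominator is one common way to avoid poles:

import jax.numpy as jnp

# illustrative trainable rational activation P(x) / Q(x)
# p: 4 numerator coefficients, q: 2 denominator coefficients
def rational_activation(x, p, q):
    num = p[0] + p[1] * x + p[2] * x**2 + p[3] * x**3
    den = 1.0 + jnp.abs(q[0] * x + q[1] * x**2)
    return num / den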
2. Data handler
When training neural networks, it is very important to normalize the training data beforehand.
CriticalData is a module that does this pre-processing for you.
import jaxfss

Ls = ...      # system size arrays
Ts = ...      # temperature arrays
As = ...      # observable arrays
As_err = ...  # observable error arrays

dataset = jaxfss.CriticalData(Ls, Ts, As, As_err)
# dict with keys "system_size", "temperature", "observable", "observable_var"
train_data = dataset.training_data
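The returned dict can be inspected directly, e.g. (a sketch, assuming the arrays above have been filled in):

# each key listed above maps to an array prepared for training
print(train_data["system_size"].shape)
print(train_data["temperature"].shape)
print(train_data["observable"].shape)
print(train_data["observable_var"].shape)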
You can also initialize the module from a file:
import jaxfss
dataset = jaxfss.CriticalData.from_file(fname="filename.txt")
train_data = dataset.training_data
3. Loss function
We provide helper functions for constructing loss functions.
MSELoss is the mean squared error, and NLLLoss is the negative log-likelihood loss.
The function names are derived from PyTorch.
Check out how they are used in the Examples.
import jaxfss

def loss_fn(params):
    ...
    y_true = ...
    y_pred = ...
    return jaxfss.MSELoss(y_true, y_pred)
MSELoss reads

\[
\mathrm{MSELoss}\left(y^{\mathrm{true}}, y^{\mathrm{pred}}\right) = \frac{1}{N} \sum_{i=1}^{N} \left( y^{\mathrm{true}}_i - y^{\mathrm{pred}}_i \right)^2,
\]

and NLLLoss reads, up to an additive constant,

\[
\mathrm{NLLLoss}\left(y^{\mathrm{true}}, y^{\mathrm{pred}}, \sigma^2\right) = \frac{1}{N} \sum_{i=1}^{N} \left[ \frac{\left( y^{\mathrm{true}}_i - y^{\mathrm{pred}}_i \right)^2}{2 \sigma_i^2} + \frac{1}{2} \log \sigma_i^2 \right],
\]

where \(\sigma_i^2\) is the variance of the \(i\)-th observable.
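For example, NLLLoss can be combined with the variance from the data handler. A sketch in the same style as above; note that passing the variance as a third argument is our assumption about the signature:

import jaxfss

def loss_fn(params):
    ...
    y_true = ...
    y_pred = ...
    y_var = ...  # e.g. train_data["observable_var"] (assumed third argument)
    return jaxfss.NLLLoss(y_true, y_pred, y_var)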
4. Train
Once everything is ready, all you need to do is train!
The fit function is a wrapper that runs the training with Optax.
Note that init_params should be a dict with the keys "mlp" and "fss".
import jaxfss
import optax

init_params = {
    "mlp": mlp_params,
    "fss": ...
}

def loss_fn(params):
    ...
    return ...

optimizer = optax.adam(learning_rate=10**-3)
steps = 10**4
params, losses, fsses = jaxfss.fit(loss_fn, optimizer, init_params, steps)
The fit function returns the final params together with the values of the loss and the fss parameters recorded during the learning process.
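To check convergence, you can, for instance, plot the recorded losses (a sketch assuming losses holds one value per optimization step):

import matplotlib.pyplot as plt

# loss history recorded during training
plt.plot(losses)
plt.xlabel("step")
plt.ylabel("loss")
plt.yscale("log")
plt.show()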