@aryan-f
Last active January 11, 2024 11:43
Standard Scaler for PyTorch Tensors
```python
import torch


class StandardScaler:
    def __init__(self, mean=None, std=None, epsilon=1e-7):
        """Standard Scaler.

        The class normalizes PyTorch tensors using native torch functions. It does not expect the
        tensors to have any specific shape; as long as the features are in the last dimension of the
        tensor, it will work.

        :param mean: The mean of the features. The property is set after a call to fit.
        :param std: The standard deviation of the features. The property is set after a call to fit.
        :param epsilon: Small constant added to the standard deviation so that constant features
            (std == 0) do not produce inf/NaN values through division by zero.
        """
        self.mean = mean
        self.std = std
        self.epsilon = epsilon

    def fit(self, values):
        # Reduce over every dimension except the last (feature) dimension.
        dims = tuple(range(values.dim() - 1))
        self.mean = torch.mean(values, dim=dims)
        # torch.std applies Bessel's correction (divides by N - 1) by default.
        self.std = torch.std(values, dim=dims)
        return self

    def transform(self, values):
        # Broadcasting applies the per-feature statistics along the last dimension.
        return (values - self.mean) / (self.std + self.epsilon)

    def fit_transform(self, values):
        return self.fit(values).transform(values)
```
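The self-contained sketch below (not part of the original gist) mirrors the logic of `fit` and `transform` using two hypothetical helpers, `fit_stats` and `standardize`, and checks the behavior on made-up random data. It shows that one mean and one standard deviation are computed per feature over all leading dimensions, and that `epsilon` keeps a constant feature finite instead of producing NaN.

```python
import torch


# Hypothetical helpers that restate the gist's fit/transform logic so this runs on its own.
def fit_stats(values: torch.Tensor):
    # Reduce over every dimension except the last (feature) dimension.
    dims = tuple(range(values.dim() - 1))
    return values.mean(dim=dims), values.std(dim=dims)


def standardize(values, mean, std, epsilon=1e-7):
    return (values - mean) / (std + epsilon)


torch.manual_seed(0)
# Made-up data: 8 sequences x 20 time steps x 3 features,
# each feature with its own scale and offset.
scale = torch.tensor([1.0, 10.0, 0.1])
offset = torch.tensor([5.0, -3.0, 0.5])
x = torch.randn(8, 20, 3) * scale + offset

mean, std = fit_stats(x)
z = standardize(x, mean, std)

print(mean.shape)  # torch.Size([3]): one statistic per feature
flat = z.reshape(-1, 3)
print(flat.mean(dim=0))  # approximately 0 for every feature
print(flat.std(dim=0))   # approximately 1 for every feature

# A constant feature has std == 0: epsilon keeps the output finite (all zeros),
# while dividing by the raw std gives 0 / 0 = NaN.
const = torch.full((4, 2), 7.0)
c_mean, c_std = fit_stats(const)
print(standardize(const, c_mean, c_std))  # all zeros
print((const - c_mean) / c_std)           # all NaN
```

Note that `torch.std` divides by N - 1 by default, whereas scikit-learn's `StandardScaler` uses the population standard deviation (divides by N). If matching scikit-learn matters for your use case, passing `correction=0` (or `unbiased=False` on older PyTorch versions) to `torch.std` should give the population statistic.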
@teodorkasap

thank you

@ahmedius2

Thanks!
