aryan-f/8a416f33a27d73a149f92ce4708beb40, last active January 11, 2024 11:43
Standard Scaler for PyTorch Tensors
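For reference, this is the transformation the class below appears to compute, assuming the features sit in the last dimension and all leading dimensions are flattened into $N$ rows indexed by $i$. Note that `torch.std` applies Bessel's correction by default, hence the $N-1$ denominator, and $\epsilon$ is the class's `epsilon` argument:

$$
\mu_j = \frac{1}{N}\sum_{i=1}^{N} x_{ij}, \qquad
\sigma_j = \sqrt{\frac{1}{N-1}\sum_{i=1}^{N}\left(x_{ij}-\mu_j\right)^2}, \qquad
z_{ij} = \frac{x_{ij}-\mu_j}{\sigma_j+\epsilon}
$$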
```python
import torch


class StandardScaler:
    def __init__(self, mean=None, std=None, epsilon=1e-7):
        """Standard Scaler.

        The class can be used to normalize PyTorch tensors using native functions. It does not expect
        the tensors to have any specific shape; as long as the features are in the last dimension of
        the tensor, the scaler will work.

        :param mean: The mean of the features. The property is set after a call to fit.
        :param std: The standard deviation of the features. The property is set after a call to fit.
        :param epsilon: Small constant added to the standard deviation so that a zero-variance feature
            does not cause a division by zero (which in PyTorch yields inf/NaN rather than raising).
        """
        self.mean = mean
        self.std = std
        self.epsilon = epsilon

    def fit(self, values):
        # Reduce over every dimension except the last (feature) dimension.
        dims = list(range(values.dim() - 1))
        self.mean = torch.mean(values, dim=dims)
        # torch.std uses the unbiased (N - 1) estimator by default.
        self.std = torch.std(values, dim=dims)

    def transform(self, values):
        # mean and std have shape (num_features,) and broadcast against the last dimension.
        return (values - self.mean) / (self.std + self.epsilon)

    def fit_transform(self, values):
        # Compute the statistics on `values`, then standardize the same tensor.
        self.fit(values)
        return self.transform(values)
```
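A common follow-up need is mapping standardized values back to the original scale. The sketch below is a hypothetical variant, `InvertibleStandardScaler` (not part of the gist), that keeps the same fit logic and adds an `inverse_transform` method; it also moves the stored statistics to the input's device and dtype with `Tensor.to(other)`, an assumption made so that a scaler fitted on CPU can be applied to GPU tensors.

```python
import torch


class InvertibleStandardScaler:
    """Hypothetical variant of the gist's StandardScaler with an inverse_transform method.

    Statistics are computed over every dimension except the last (feature) one,
    as in the original fit().
    """

    def __init__(self, epsilon: float = 1e-7):
        self.mean = None
        self.std = None
        self.epsilon = epsilon

    def fit(self, values: torch.Tensor) -> "InvertibleStandardScaler":
        dims = list(range(values.dim() - 1))
        self.mean = values.mean(dim=dims)
        self.std = values.std(dim=dims)  # unbiased (N - 1) estimate, the torch.std default
        return self

    def transform(self, values: torch.Tensor) -> torch.Tensor:
        # Match the input's device and dtype before broadcasting over the last dimension.
        mean = self.mean.to(values)
        std = self.std.to(values)
        return (values - mean) / (std + self.epsilon)

    def inverse_transform(self, values: torch.Tensor) -> torch.Tensor:
        # Undo transform(): scale by the same (std + epsilon), then add the mean back.
        mean = self.mean.to(values)
        std = self.std.to(values)
        return values * (std + self.epsilon) + mean


# Usage: a (batch, time, features) tensor with different per-feature offsets and scales.
torch.manual_seed(0)
scales = torch.tensor([1.0, 5.0, 0.1, 20.0])
offsets = torch.tensor([0.0, 3.0, -2.0, 100.0])
x = torch.randn(32, 10, 4) * scales + offsets

scaler = InvertibleStandardScaler().fit(x)
z = scaler.transform(x)
x_back = scaler.inverse_transform(z)

print(z.mean(dim=(0, 1)))  # each feature's mean should be close to 0
print(z.std(dim=(0, 1)))   # each feature's std should be close to 1
print(torch.allclose(x, x_back, atol=1e-4))
```

Returning `self` from `fit` is a design choice borrowed from scikit-learn's convention so that calls can be chained; the gist's own `fit` returns nothing.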