sadhaklal commited on
Commit
365a3de
·
verified ·
1 Parent(s): 1dbfe55

added "Usage" section to README.md

Browse files
Files changed (1) hide show
  1. README.md +45 -0
README.md CHANGED
@@ -25,7 +25,52 @@ Experiment tracking: https://wandb.ai/sadhaklal/mlp-california-housing
25
  ## Usage
26
 
27
  ```
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  ```
30
 
31
  ## Metric
 
25
  ## Usage
26
 
27
  ```
28
+ from sklearn.datasets import fetch_california_housing
29
 
30
+ housing = fetch_california_housing(as_frame=True)
31
+
32
+ from sklearn.model_selection import train_test_split
33
+
34
+ X_train_full, X_test, y_train_full, y_test = train_test_split(housing['data'], housing['target'], test_size=0.25, random_state=42)
35
+ X_train, X_valid, y_train, y_valid = train_test_split(X_train_full, y_train_full, test_size=0.25, random_state=42)
36
+
37
+ X_means, X_stds = X_train.mean(axis=0), X_train.std(axis=0)
38
+ X_train = (X_train - X_means) / X_stds
39
+ X_valid = (X_valid - X_means) / X_stds
40
+ X_test = (X_test - X_means) / X_stds
41
+
42
+ import torch
43
+
44
+ device = torch.device("cpu")
45
+
46
+ import torch.nn as nn
47
+ from huggingface_hub import PyTorchModelHubMixin
48
+
49
+ class MLP(nn.Module, PyTorchModelHubMixin):
50
+ def __init__(self):
51
+ super().__init__()
52
+ self.fc1 = nn.Linear(8, 50)
53
+ self.fc2 = nn.Linear(50, 50)
54
+ self.fc3 = nn.Linear(50, 50)
55
+ self.fc4 = nn.Linear(50, 1)
56
+
57
+ def forward(self, x):
58
+ act = torch.relu(self.fc1(x))
59
+ act = torch.relu(self.fc2(act))
60
+ act = torch.relu(self.fc3(act))
61
+ return self.fc4(act)
62
+
63
+ model = MLP.from_pretrained("sadhaklal/mlp-california-housing")
64
+ model.to(device)
65
+ model.eval()
66
+
67
+ # Let's predict on 3 unseen examples from the test set:
68
+ print(f"Ground truth housing prices: {y_test.values[:3]}")
69
+ x_new = torch.tensor(X_test.values[:3], dtype=torch.float32)
70
+ x_new = x_new.to(device)
71
+ with torch.no_grad():
72
+ preds = model(x_new)
73
+ print(f"Predicted housing prices: {preds.squeeze()}")
74
  ```
75
 
76
  ## Metric