TF-Keras
English
markub3327 commited on
Commit
5671ee4
·
1 Parent(s): 62ca2e7
Files changed (6) hide show
  1. README.md +41 -3
  2. Simulator.ipynb +1071 -0
  3. Training.ipynb +1433 -0
  4. img/Solar_Transformer.png +0 -0
  5. img/output.png +0 -0
  6. models/model-best.h5 +3 -0
README.md CHANGED
@@ -1,3 +1,41 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Solar Transformer
2
+
3
+ Please check our paper [Solar Irradiance Forecasting with Transformer model
4
+ ](https://www.mdpi.com/2076-3417/12/17/8852) for more details.
5
+
6
+ [![Issues](https://img.shields.io/github/issues/markub3327/Solar-Transformer)](https://github.com/markub3327/Solar-Transformer/issues)
7
+ ![Commits](https://img.shields.io/github/commit-activity/w/markub3327/Solar-Transformer)
8
+ ![Size](https://img.shields.io/github/repo-size/markub3327/Solar-Transformer)
9
+
10
+ ## Paper
11
+
12
+ * Vaswani, A.; Shazeer, N.; Parmar, N.; Uszkoreit, J.; Jones, L.; Gomez, A.N.; Kaiser, Ł.; Polosukhin, I. Attention is all you need. Advances in neural information processing systems 2017, 30.
13
+ * Dosovitskiy, A.; Beyer, L.; Kolesnikov, A.; Weissenborn, D.; Zhai, X.; Unterthiner, T.; Dehghani, M.; Minderer, M.; Heigold, G.; Gelly, S.; Uszkoreit, J. An image is worth 16x16 words: Transformers for image recognition at scale. 2020, arXiv preprint arXiv:2010.11929.
14
+ * Bao, H.; Dong, L.; Wei, F. Beit: Bert pre-training of image transformers. 2021, arXiv preprint arXiv:2106.08254.
15
+ * Brahma, B.; Wadhvani, R. Solar irradiance forecasting based on deep learning methodologies and multi-site data. Sym-metry 2020, 12(11), p.1830. Available online: https://www.mdpi.com/2073-8994/12/11/1830
16
+
17
+ ## About
18
+
19
+ Solar energy is one of the most popular sources of renewable energy today. It is therefore essential to be able to predict solar power generation and adapt the energy needs to these predictions. This paper uses Transformer deep neural network model, in which the attention mechanism is typically applied in NLP or vision problems. Here it is extended by combining features based on their spatio-temporal properties in solar irradiance prediction. The results were predicted for arbitrary long-time horizons since the prediction is always 1 day ahead, which can be included at the end along the timestep axis of the input data and the first timestep representing the oldest timestep removed. A maximum worst-case mean absolute percentage error of 3.45% for the 1 day-ahead prediction was achieved, thus providing better results than the directly competing method.
20
+
21
+ ## Dataset
22
+
23
+ [NASA POWER Project](https://power.larc.nasa.gov)
24
+
25
+ Solar irradiance + Weather (temperature, humidity, pressure, wind speed, wind direction)
26
+
27
+ ## Model
28
+
29
+ <p align="center">
30
+ <img src="img/Solar_Transformer.png">
31
+ </p>
32
+
33
+ ## Results
34
+
35
+ <p align="center">
36
+ <img src="img/output.png">
37
+ </p>
38
+
39
+ ----------------------------------
40
+
41
+ **Frameworks:** TensorFlow, NumPy, Pandas, WanDB, Seaborn, Matplotlib
Simulator.ipynb ADDED
@@ -0,0 +1,1071 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "!pip install wandb tensorflow_probability tensorflow_addons"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "from tensorflow.keras.layers import Add, Dense, Dropout, Layer, LayerNormalization, MultiHeadAttention\n",
19
+ "from tensorflow.keras.models import Model\n",
20
+ "from tensorflow.keras.initializers import TruncatedNormal\n",
21
+ "from tensorflow.keras.metrics import MeanSquaredError, RootMeanSquaredError, MeanAbsoluteError, MeanAbsolutePercentageError\n",
22
+ "from tensorflow_addons.metrics import RSquare\n",
23
+ "\n",
24
+ "import pandas as pd\n",
25
+ "import tensorflow as tf\n",
26
+ "import numpy as np\n",
27
+ "import matplotlib.pyplot as plt\n",
28
+ "import seaborn as sns"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "markdown",
33
+ "metadata": {},
34
+ "source": [
35
+ "## Plotting"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "def plot_prediction(targets, predictions, max_subplots=3):\n",
45
+ " plt.figure(figsize=(12, 15))\n",
46
+ " max_n = min(max_subplots, len(targets))\n",
47
+ " for n in range(max_n):\n",
48
+ " # input\n",
49
+ " plt.subplot(max_n, 1, n+1)\n",
50
+ " plt.ylabel('Solar irradiance [kW-hr/m^2/day]', fontfamily=\"Arial\", fontsize=16)\n",
51
+ " plt.plot(np.arange(targets.shape[1]-horizon), targets[n, :-horizon, 0, -1], label='Inputs', marker='.', zorder=-10)\n",
52
+ "\n",
53
+ " # real\n",
54
+ " plt.scatter(np.arange(1, targets.shape[1]), targets[n, 1:, 0, -1], edgecolors='k', label='Targets', c='#2cb01d', s=64)\n",
55
+ " \n",
56
+ " # predicted\n",
57
+ " plt.scatter(np.arange(1, targets.shape[1]), predictions[n, :, 0, -1], marker='X', edgecolors='k', label='Predictions', c='#fe7e0f', s=64)\n",
58
+ "\n",
59
+ " if n == 0:\n",
60
+ " plt.legend()\n",
61
+ "\n",
62
+ " plt.xlabel('Time [day]', fontfamily=\"Arial\", fontsize=16)\n",
63
+ " plt.show()"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": null,
69
+ "metadata": {},
70
+ "outputs": [],
71
+ "source": [
72
+ "def patch_similarity_plot(pos):\n",
73
+ " similarity_scores = np.dot(\n",
74
+ " pos, np.transpose(pos)\n",
75
+ " ) / (\n",
76
+ " np.linalg.norm(pos, axis=-1)\n",
77
+ " * np.linalg.norm(pos, axis=-1)\n",
78
+ " )\n",
79
+ "\n",
80
+ " plt.figure(figsize=(7, 7), dpi=300)\n",
81
+ " ax = sns.heatmap(similarity_scores, center=0)\n",
82
+ " ax.set_title(\"Spatial Positional Embedding\", fontfamily=\"Arial\", fontsize=16)\n",
83
+ " ax.set_xlabel(\"Patch\", fontfamily=\"Arial\", fontsize=16)\n",
84
+ " ax.set_ylabel(\"Patch\", fontfamily=\"Arial\", fontsize=16)\n",
85
+ " plt.show()\n",
86
+ "\n",
87
+ "def timestep_similarity_plot(pos):\n",
88
+ " similarity_scores = np.dot(\n",
89
+ " pos, np.transpose(pos)\n",
90
+ " ) / (\n",
91
+ " np.linalg.norm(pos, axis=-1)\n",
92
+ " * np.linalg.norm(pos, axis=-1)\n",
93
+ " )\n",
94
+ "\n",
95
+ " plt.figure(figsize=(7, 7), dpi=300)\n",
96
+ " ax = sns.heatmap(similarity_scores, center=0)\n",
97
+ " ax.set_title(\"Temporal Positional Embedding\", fontfamily=\"Arial\", fontsize=16)\n",
98
+ " ax.set_xlabel(\"Timestep\", fontfamily=\"Arial\", fontsize=16)\n",
99
+ " ax.set_ylabel(\"Timestep\", fontfamily=\"Arial\", fontsize=16)\n",
100
+ " plt.show()"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "markdown",
105
+ "metadata": {},
106
+ "source": [
107
+ "## Layer"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "metadata": {},
114
+ "outputs": [],
115
+ "source": [
116
+ "class Normalization(tf.keras.layers.experimental.preprocessing.PreprocessingLayer):\n",
117
+ " \"\"\"A preprocessing layer which normalizes continuous features.\n",
118
+ " This layer will shift and scale inputs into a distribution centered around\n",
119
+ " 0 with standard deviation 1. It accomplishes this by precomputing the mean\n",
120
+ " and variance of the data, and calling `(input - mean) / sqrt(var)` at\n",
121
+ " runtime.\n",
122
+ " The mean and variance values for the layer must be either supplied on\n",
123
+ " construction or learned via `adapt()`. `adapt()` will compute the mean and\n",
124
+ " variance of the data and store them as the layer's weights. `adapt()` should\n",
125
+ " be called before `fit()`, `evaluate()`, or `predict()`.\n",
126
+ " For an overview and full list of preprocessing layers, see the preprocessing\n",
127
+ " [guide](https://www.tensorflow.org/guide/keras/preprocessing_layers).\n",
128
+ " Args:\n",
129
+ " axis: Integer, tuple of integers, or None. The axis or axes that should\n",
130
+ " have a separate mean and variance for each index in the shape. For\n",
131
+ " example, if shape is `(None, 5)` and `axis=1`, the layer will track 5\n",
132
+ " separate mean and variance values for the last axis. If `axis` is set\n",
133
+ " to `None`, the layer will normalize all elements in the input by a\n",
134
+ " scalar mean and variance. Defaults to -1, where the last axis of the\n",
135
+ " input is assumed to be a feature dimension and is normalized per\n",
136
+ " index. Note that in the specific case of batched scalar inputs where\n",
137
+ " the only axis is the batch axis, the default will normalize each index\n",
138
+ " in the batch separately. In this case, consider passing `axis=None`.\n",
139
+ " mean: The mean value(s) to use during normalization. The passed value(s)\n",
140
+ " will be broadcast to the shape of the kept axes above; if the value(s)\n",
141
+ " cannot be broadcast, an error will be raised when this layer's\n",
142
+ " `build()` method is called.\n",
143
+ " variance: The variance value(s) to use during normalization. The passed\n",
144
+ " value(s) will be broadcast to the shape of the kept axes above; if the\n",
145
+ " value(s) cannot be broadcast, an error will be raised when this\n",
146
+ " layer's `build()` method is called.\n",
147
+ " invert: If True, this layer will apply the inverse transformation\n",
148
+ " to its inputs: it would turn a normalized input back into its\n",
149
+ " original form.\n",
150
+ " Examples:\n",
151
+ " Calculate a global mean and variance by analyzing the dataset in `adapt()`.\n",
152
+ " >>> adapt_data = np.array([1., 2., 3., 4., 5.], dtype='float32')\n",
153
+ " >>> input_data = np.array([1., 2., 3.], dtype='float32')\n",
154
+ " >>> layer = tf.keras.layers.Normalization(axis=None)\n",
155
+ " >>> layer.adapt(adapt_data)\n",
156
+ " >>> layer(input_data)\n",
157
+ " <tf.Tensor: shape=(3,), dtype=float32, numpy=\n",
158
+ " array([-1.4142135, -0.70710677, 0.], dtype=float32)>\n",
159
+ " Calculate a mean and variance for each index on the last axis.\n",
160
+ " >>> adapt_data = np.array([[0., 7., 4.],\n",
161
+ " ... [2., 9., 6.],\n",
162
+ " ... [0., 7., 4.],\n",
163
+ " ... [2., 9., 6.]], dtype='float32')\n",
164
+ " >>> input_data = np.array([[0., 7., 4.]], dtype='float32')\n",
165
+ " >>> layer = tf.keras.layers.Normalization(axis=-1)\n",
166
+ " >>> layer.adapt(adapt_data)\n",
167
+ " >>> layer(input_data)\n",
168
+ " <tf.Tensor: shape=(1, 3), dtype=float32, numpy=\n",
169
+ " array([-1., -1., -1.], dtype=float32)>\n",
170
+ " Pass the mean and variance directly.\n",
171
+ " >>> input_data = np.array([[1.], [2.], [3.]], dtype='float32')\n",
172
+ " >>> layer = tf.keras.layers.Normalization(mean=3., variance=2.)\n",
173
+ " >>> layer(input_data)\n",
174
+ " <tf.Tensor: shape=(3, 1), dtype=float32, numpy=\n",
175
+ " array([[-1.4142135 ],\n",
176
+ " [-0.70710677],\n",
177
+ " [ 0. ]], dtype=float32)>\n",
178
+ " Use the layer to de-normalize inputs (after adapting the layer).\n",
179
+ " >>> adapt_data = np.array([[0., 7., 4.],\n",
180
+ " ... [2., 9., 6.],\n",
181
+ " ... [0., 7., 4.],\n",
182
+ " ... [2., 9., 6.]], dtype='float32')\n",
183
+ " >>> input_data = np.array([[1., 2., 3.]], dtype='float32')\n",
184
+ " >>> layer = tf.keras.layers.Normalization(axis=-1, invert=True)\n",
185
+ " >>> layer.adapt(adapt_data)\n",
186
+ " >>> layer(input_data)\n",
187
+ " <tf.Tensor: shape=(1, 3), dtype=float32, numpy=\n",
188
+ " array([2., 10., 8.], dtype=float32)>\n",
189
+ " \"\"\"\n",
190
+ "\n",
191
+ " def __init__(\n",
192
+ " self, axis=-1, mean=None, variance=None, invert=False, **kwargs\n",
193
+ " ):\n",
194
+ " super().__init__(**kwargs)\n",
195
+ "\n",
196
+ " # Standardize `axis` to a tuple.\n",
197
+ " if axis is None:\n",
198
+ " axis = ()\n",
199
+ " elif isinstance(axis, int):\n",
200
+ " axis = (axis,)\n",
201
+ " else:\n",
202
+ " axis = tuple(axis)\n",
203
+ " self.axis = axis\n",
204
+ "\n",
205
+ " # Set `mean` and `variance` if passed.\n",
206
+ " if isinstance(mean, tf.Variable):\n",
207
+ " raise ValueError(\n",
208
+ " \"Normalization does not support passing a Variable \"\n",
209
+ " \"for the `mean` init arg.\"\n",
210
+ " )\n",
211
+ " if isinstance(variance, tf.Variable):\n",
212
+ " raise ValueError(\n",
213
+ " \"Normalization does not support passing a Variable \"\n",
214
+ " \"for the `variance` init arg.\"\n",
215
+ " )\n",
216
+ " if (mean is not None) != (variance is not None):\n",
217
+ " raise ValueError(\n",
218
+ " \"When setting values directly, both `mean` and `variance` \"\n",
219
+ " \"must be set. Got mean: {} and variance: {}\".format(\n",
220
+ " mean, variance\n",
221
+ " )\n",
222
+ " )\n",
223
+ " self.input_mean = mean\n",
224
+ " self.input_variance = variance\n",
225
+ " self.invert = invert\n",
226
+ "\n",
227
+ " def build(self, input_shape):\n",
228
+ " super().build(input_shape)\n",
229
+ "\n",
230
+ " if isinstance(input_shape, (list, tuple)) and all(\n",
231
+ " isinstance(shape, tf.TensorShape) for shape in input_shape\n",
232
+ " ):\n",
233
+ " raise ValueError(\n",
234
+ " \"Normalization only accepts a single input. If you are \"\n",
235
+ " \"passing a python list or tuple as a single input, \"\n",
236
+ " \"please convert to a numpy array or `tf.Tensor`.\"\n",
237
+ " )\n",
238
+ "\n",
239
+ " input_shape = tf.TensorShape(input_shape).as_list()\n",
240
+ " ndim = len(input_shape)\n",
241
+ "\n",
242
+ " if any(a < -ndim or a >= ndim for a in self.axis):\n",
243
+ " raise ValueError(\n",
244
+ " \"All `axis` values must be in the range [-ndim, ndim). \"\n",
245
+ " \"Found ndim: `{}`, axis: {}\".format(ndim, self.axis)\n",
246
+ " )\n",
247
+ "\n",
248
+ " # Axes to be kept, replacing negative values with positive equivalents.\n",
249
+ " # Sorted to avoid transposing axes.\n",
250
+ " self._keep_axis = sorted([d if d >= 0 else d + ndim for d in self.axis])\n",
251
+ " # All axes to be kept should have known shape.\n",
252
+ " for d in self._keep_axis:\n",
253
+ " if input_shape[d] is None:\n",
254
+ " raise ValueError(\n",
255
+ " \"All `axis` values to be kept must have known shape. \"\n",
256
+ " \"Got axis: {}, \"\n",
257
+ " \"input shape: {}, with unknown axis at index: {}\".format(\n",
258
+ " self.axis, input_shape, d\n",
259
+ " )\n",
260
+ " )\n",
261
+ " # Axes to be reduced.\n",
262
+ " self._reduce_axis = [d for d in range(ndim) if d not in self._keep_axis]\n",
263
+ " # 1 if an axis should be reduced, 0 otherwise.\n",
264
+ " self._reduce_axis_mask = [\n",
265
+ " 0 if d in self._keep_axis else 1 for d in range(ndim)\n",
266
+ " ]\n",
267
+ " # Broadcast any reduced axes.\n",
268
+ " self._broadcast_shape = [\n",
269
+ " input_shape[d] if d in self._keep_axis else 1 for d in range(ndim)\n",
270
+ " ]\n",
271
+ " mean_and_var_shape = tuple(input_shape[d] for d in self._keep_axis)\n",
272
+ "\n",
273
+ " if self.input_mean is None:\n",
274
+ " self.adapt_mean = self.add_weight(\n",
275
+ " name=\"mean\",\n",
276
+ " shape=mean_and_var_shape,\n",
277
+ " dtype=self.compute_dtype,\n",
278
+ " initializer=\"zeros\",\n",
279
+ " trainable=False,\n",
280
+ " )\n",
281
+ " self.adapt_variance = self.add_weight(\n",
282
+ " name=\"variance\",\n",
283
+ " shape=mean_and_var_shape,\n",
284
+ " dtype=self.compute_dtype,\n",
285
+ " initializer=\"ones\",\n",
286
+ " trainable=False,\n",
287
+ " )\n",
288
+ " self.count = self.add_weight(\n",
289
+ " name=\"count\",\n",
290
+ " shape=(),\n",
291
+ " dtype=tf.int64,\n",
292
+ " initializer=\"zeros\",\n",
293
+ " trainable=False,\n",
294
+ " )\n",
295
+ " self.finalize_state()\n",
296
+ " else:\n",
297
+ " # In the no adapt case, make constant tensors for mean and variance\n",
298
+ " # with proper broadcast shape for use during call.\n",
299
+ " mean = self.input_mean * np.ones(mean_and_var_shape)\n",
300
+ " variance = self.input_variance * np.ones(mean_and_var_shape)\n",
301
+ " mean = tf.reshape(mean, self._broadcast_shape)\n",
302
+ " variance = tf.reshape(variance, self._broadcast_shape)\n",
303
+ " self.mean = tf.cast(mean, self.compute_dtype)\n",
304
+ " self.variance = tf.cast(variance, self.compute_dtype)\n",
305
+ "\n",
306
+ " # We override this method solely to generate a docstring.\n",
307
+ " def adapt(self, data, batch_size=None, steps=None):\n",
308
+ " \"\"\"Computes the mean and variance of values in a dataset.\n",
309
+ " Calling `adapt()` on a `Normalization` layer is an alternative to\n",
310
+ " passing in `mean` and `variance` arguments during layer construction. A\n",
311
+ " `Normalization` layer should always either be adapted over a dataset or\n",
312
+ " passed `mean` and `variance`.\n",
313
+ " During `adapt()`, the layer will compute a `mean` and `variance`\n",
314
+ " separately for each position in each axis specified by the `axis`\n",
315
+ " argument. To calculate a single `mean` and `variance` over the input\n",
316
+ " data, simply pass `axis=None`.\n",
317
+ " In order to make `Normalization` efficient in any distribution context,\n",
318
+ " the computed mean and variance are kept static with respect to any\n",
319
+ " compiled `tf.Graph`s that call the layer. As a consequence, if the layer\n",
320
+ " is adapted a second time, any models using the layer should be\n",
321
+ " re-compiled. For more information see\n",
322
+ " `tf.keras.layers.experimental.preprocessing.PreprocessingLayer.adapt`.\n",
323
+ " `adapt()` is meant only as a single machine utility to compute layer\n",
324
+ " state. To analyze a dataset that cannot fit on a single machine, see\n",
325
+ " [Tensorflow Transform](\n",
326
+ " https://www.tensorflow.org/tfx/transform/get_started)\n",
327
+ " for a multi-machine, map-reduce solution.\n",
328
+ " Arguments:\n",
329
+ " data: The data to train on. It can be passed either as a\n",
330
+ " `tf.data.Dataset`, or as a numpy array.\n",
331
+ " batch_size: Integer or `None`.\n",
332
+ " Number of samples per state update.\n",
333
+ " If unspecified, `batch_size` will default to 32.\n",
334
+ " Do not specify the `batch_size` if your data is in the\n",
335
+ " form of datasets, generators, or `keras.utils.Sequence` instances\n",
336
+ " (since they generate batches).\n",
337
+ " steps: Integer or `None`.\n",
338
+ " Total number of steps (batches of samples)\n",
339
+ " When training with input tensors such as\n",
340
+ " TensorFlow data tensors, the default `None` is equal to\n",
341
+ " the number of samples in your dataset divided by\n",
342
+ " the batch size, or 1 if that cannot be determined. If x is a\n",
343
+ " `tf.data` dataset, and 'steps' is None, the epoch will run until\n",
344
+ " the input dataset is exhausted. When passing an infinitely\n",
345
+ " repeating dataset, you must specify the `steps` argument. This\n",
346
+ " argument is not supported with array inputs.\n",
347
+ " \"\"\"\n",
348
+ " super().adapt(data, batch_size=batch_size, steps=steps)\n",
349
+ "\n",
350
+ " def update_state(self, data):\n",
351
+ " if self.input_mean is not None:\n",
352
+ " raise ValueError(\n",
353
+ " \"Cannot `adapt` a Normalization layer that is initialized with \"\n",
354
+ " \"static `mean` and `variance`, \"\n",
355
+ " \"you passed mean {} and variance {}.\".format(\n",
356
+ " self.input_mean, self.input_variance\n",
357
+ " )\n",
358
+ " )\n",
359
+ "\n",
360
+ " if not self.built:\n",
361
+ " raise RuntimeError(\"`build` must be called before `update_state`.\")\n",
362
+ "\n",
363
+ " data = self._standardize_inputs(data)\n",
364
+ " data = tf.cast(data, self.adapt_mean.dtype)\n",
365
+ " batch_mean, batch_variance = tf.nn.moments(data, axes=self._reduce_axis)\n",
366
+ " batch_shape = tf.shape(data, out_type=self.count.dtype)\n",
367
+ " if self._reduce_axis:\n",
368
+ " batch_reduce_shape = tf.gather(batch_shape, self._reduce_axis)\n",
369
+ " batch_count = tf.reduce_prod(batch_reduce_shape)\n",
370
+ " else:\n",
371
+ " batch_count = 1\n",
372
+ "\n",
373
+ " total_count = batch_count + self.count\n",
374
+ " batch_weight = tf.cast(batch_count, dtype=self.compute_dtype) / tf.cast(\n",
375
+ " total_count, dtype=self.compute_dtype\n",
376
+ " )\n",
377
+ " existing_weight = 1.0 - batch_weight\n",
378
+ "\n",
379
+ " total_mean = (\n",
380
+ " self.adapt_mean * existing_weight + batch_mean * batch_weight\n",
381
+ " )\n",
382
+ " # The variance is computed using the lack-of-fit sum of squares\n",
383
+ " # formula (see\n",
384
+ " # https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).\n",
385
+ " total_variance = (\n",
386
+ " self.adapt_variance + (self.adapt_mean - total_mean) ** 2\n",
387
+ " ) * existing_weight + (\n",
388
+ " batch_variance + (batch_mean - total_mean) ** 2\n",
389
+ " ) * batch_weight\n",
390
+ " self.adapt_mean.assign(total_mean)\n",
391
+ " self.adapt_variance.assign(total_variance)\n",
392
+ " self.count.assign(total_count)\n",
393
+ "\n",
394
+ " def reset_state(self):\n",
395
+ " if self.input_mean is not None or not self.built:\n",
396
+ " return\n",
397
+ "\n",
398
+ " self.adapt_mean.assign(tf.zeros_like(self.adapt_mean))\n",
399
+ " self.adapt_variance.assign(tf.ones_like(self.adapt_variance))\n",
400
+ " self.count.assign(tf.zeros_like(self.count))\n",
401
+ "\n",
402
+ " def finalize_state(self):\n",
403
+ " if self.input_mean is not None or not self.built:\n",
404
+ " return\n",
405
+ "\n",
406
+ " # In the adapt case, we make constant tensors for mean and variance with\n",
407
+ " # proper broadcast shape and dtype each time `finalize_state` is called.\n",
408
+ " self.mean = tf.reshape(self.adapt_mean, self._broadcast_shape)\n",
409
+ " self.mean = tf.cast(self.mean, self.compute_dtype)\n",
410
+ " self.variance = tf.reshape(self.adapt_variance, self._broadcast_shape)\n",
411
+ " self.variance = tf.cast(self.variance, self.compute_dtype)\n",
412
+ "\n",
413
+ " def call(self, inputs):\n",
414
+ " inputs = self._standardize_inputs(inputs)\n",
415
+ " # The base layer automatically casts floating-point inputs, but we\n",
416
+ " # explicitly cast here to also allow integer inputs to be passed\n",
417
+ " inputs = tf.cast(inputs, self.compute_dtype)\n",
418
+ " if self.invert:\n",
419
+ " return (inputs + self.mean) * tf.maximum(\n",
420
+ " tf.sqrt(self.variance), tf.keras.backend.epsilon()\n",
421
+ " )\n",
422
+ " else:\n",
423
+ " return (inputs - self.mean) / tf.maximum(\n",
424
+ " tf.sqrt(self.variance), tf.keras.backend.epsilon()\n",
425
+ " )\n",
426
+ "\n",
427
+ " def compute_output_shape(self, input_shape):\n",
428
+ " return input_shape\n",
429
+ "\n",
430
+ " def compute_output_signature(self, input_spec):\n",
431
+ " return input_spec\n",
432
+ "\n",
433
+ " def get_config(self):\n",
434
+ " config = super().get_config()\n",
435
+ " config.update(\n",
436
+ " {\n",
437
+ " \"axis\": self.axis,\n",
438
+ " \"mean\": tf.keras.layers.experimental.preprocessing.preprocessing_utils.utils.listify_tensors(self.input_mean),\n",
439
+ " \"variance\": tf.keras.layers.experimental.preprocessing.preprocessing_utils.utils.listify_tensors(self.input_variance),\n",
440
+ " }\n",
441
+ " )\n",
442
+ " return config\n",
443
+ "\n",
444
+ " def _standardize_inputs(self, inputs):\n",
445
+ " inputs = tf.convert_to_tensor(inputs)\n",
446
+ " if inputs.dtype != self.compute_dtype:\n",
447
+ " inputs = tf.cast(inputs, self.compute_dtype)\n",
448
+ " return inputs"
449
+ ]
450
+ },
451
+ {
452
+ "cell_type": "code",
453
+ "execution_count": null,
454
+ "metadata": {},
455
+ "outputs": [],
456
+ "source": [
457
+ "class PositionalEmbedding(Layer):\n",
458
+ " def __init__(self, units, dropout_rate, **kwargs):\n",
459
+ " super(PositionalEmbedding, self).__init__(**kwargs)\n",
460
+ "\n",
461
+ " self.units = units\n",
462
+ "\n",
463
+ " self.projection = Dense(units, kernel_initializer=TruncatedNormal(stddev=0.02))\n",
464
+ " self.dropout = Dropout(rate=dropout_rate)\n",
465
+ "\n",
466
+ " def build(self, input_shape):\n",
467
+ " super(PositionalEmbedding, self).build(input_shape)\n",
468
+ "\n",
469
+ " print(\"pos_embbeding: \", input_shape)\n",
470
+ " self.temporal_position = self.add_weight(\n",
471
+ " name=\"temporal_position\",\n",
472
+ " shape=(1, input_shape[1], 1, self.units),\n",
473
+ " initializer=TruncatedNormal(stddev=0.02),\n",
474
+ " trainable=True,\n",
475
+ " )\n",
476
+ " self.spatial_position = self.add_weight(\n",
477
+ " name=\"spatial_position\",\n",
478
+ " shape=(1, 1, input_shape[2], self.units),\n",
479
+ " initializer=TruncatedNormal(stddev=0.02),\n",
480
+ " trainable=True,\n",
481
+ " )\n",
482
+ "\n",
483
+ " def call(self, inputs, training):\n",
484
+ " x = self.projection(inputs)\n",
485
+ " x += self.temporal_position\n",
486
+ " x += self.spatial_position\n",
487
+ "\n",
488
+ " return self.dropout(x, training=training)"
489
+ ]
490
+ },
491
+ {
492
+ "cell_type": "code",
493
+ "execution_count": null,
494
+ "metadata": {},
495
+ "outputs": [],
496
+ "source": [
497
+ "class Encoder(Layer):\n",
498
+ " def __init__(\n",
499
+ " self, embed_dim, mlp_dim, num_heads, dropout_rate, attention_dropout_rate, **kwargs\n",
500
+ " ):\n",
501
+ " super(Encoder, self).__init__(**kwargs)\n",
502
+ "\n",
503
+ " # Multi-head Attention\n",
504
+ " self.mha = MultiHeadAttention(\n",
505
+ " num_heads=num_heads,\n",
506
+ " key_dim=embed_dim,\n",
507
+ " dropout=attention_dropout_rate,\n",
508
+ " kernel_initializer=TruncatedNormal(stddev=0.02),\n",
509
+ " attention_axes=(1, 2), # 2D attention (timestep, patch)\n",
510
+ " )\n",
511
+ "\n",
512
+ " # Point wise feed forward network\n",
513
+ " self.dense_0 = Dense(\n",
514
+ " units=mlp_dim,\n",
515
+ " activation=\"gelu\",\n",
516
+ " kernel_initializer=TruncatedNormal(stddev=0.02),\n",
517
+ " )\n",
518
+ " self.dense_1 = Dense(\n",
519
+ " units=embed_dim, kernel_initializer=TruncatedNormal(stddev=0.02)\n",
520
+ " )\n",
521
+ "\n",
522
+ " self.dropout_0 = Dropout(rate=dropout_rate)\n",
523
+ " self.dropout_1 = Dropout(rate=dropout_rate)\n",
524
+ "\n",
525
+ " self.norm_0 = LayerNormalization(epsilon=1e-12)\n",
526
+ " self.norm_1 = LayerNormalization(epsilon=1e-12)\n",
527
+ "\n",
528
+ " self.add_0 = Add()\n",
529
+ " self.add_1 = Add()\n",
530
+ "\n",
531
+ " def call(self, inputs, training):\n",
532
+ " # Attention block\n",
533
+ " x = self.norm_0(inputs)\n",
534
+ " x = self.mha(\n",
535
+ " query=x,\n",
536
+ " key=x,\n",
537
+ " value=x,\n",
538
+ " training=training,\n",
539
+ " )\n",
540
+ " x = self.dropout_0(x, training=training)\n",
541
+ " x = self.add_0([x, inputs])\n",
542
+ "\n",
543
+ " # MLP block\n",
544
+ " y = self.norm_1(x)\n",
545
+ " y = self.dense_0(y)\n",
546
+ " y = self.dense_1(y)\n",
547
+ " y = self.dropout_1(y, training=training)\n",
548
+ "\n",
549
+ " return self.add_1([x, y])"
550
+ ]
551
+ },
552
+ {
553
+ "cell_type": "code",
554
+ "execution_count": null,
555
+ "metadata": {},
556
+ "outputs": [],
557
+ "source": [
558
+ "class Decoder(Layer):\n",
559
+ " def __init__(\n",
560
+ " self, embed_dim, mlp_dim, num_heads, dropout_rate, attention_dropout_rate, **kwargs\n",
561
+ " ):\n",
562
+ " super(Decoder, self).__init__(**kwargs)\n",
563
+ "\n",
564
+ " # MultiHeadAttention\n",
565
+ " self.mha_0 = MultiHeadAttention(\n",
566
+ " num_heads=num_heads,\n",
567
+ " key_dim=embed_dim,\n",
568
+ " dropout=attention_dropout_rate,\n",
569
+ " kernel_initializer=TruncatedNormal(stddev=0.02),\n",
570
+ " attention_axes=(1, 2), # 2D attention (timestep, patch)\n",
571
+ " )\n",
572
+ " self.mha_1 = MultiHeadAttention(\n",
573
+ " num_heads=num_heads,\n",
574
+ " key_dim=embed_dim,\n",
575
+ " dropout=attention_dropout_rate,\n",
576
+ " kernel_initializer=TruncatedNormal(stddev=0.02),\n",
577
+ " attention_axes=(1, 2), # 2D attention (timestep, patch)\n",
578
+ " )\n",
579
+ "\n",
580
+ " # Point wise feed forward network\n",
581
+ " self.dense_0 = Dense(\n",
582
+ " units=mlp_dim,\n",
583
+ " activation=\"gelu\",\n",
584
+ " kernel_initializer=TruncatedNormal(stddev=0.02),\n",
585
+ " )\n",
586
+ " self.dense_1 = Dense(\n",
587
+ " units=embed_dim, kernel_initializer=TruncatedNormal(stddev=0.02)\n",
588
+ " )\n",
589
+ "\n",
590
+ " self.dropout_0 = Dropout(rate=dropout_rate)\n",
591
+ " self.dropout_1 = Dropout(rate=dropout_rate)\n",
592
+ " self.dropout_2 = Dropout(rate=dropout_rate)\n",
593
+ "\n",
594
+ " self.norm_0 = LayerNormalization(epsilon=1e-12)\n",
595
+ " self.norm_1 = LayerNormalization(epsilon=1e-12)\n",
596
+ " self.norm_2 = LayerNormalization(epsilon=1e-12)\n",
597
+ "\n",
598
+ " self.add_0 = Add()\n",
599
+ " self.add_1 = Add()\n",
600
+ " self.add_2 = Add()\n",
601
+ "\n",
602
+ " def call(self, inputs, enc_output, training):\n",
603
+ " # Attention block\n",
604
+ " x = self.norm_0(inputs)\n",
605
+ " x = self.mha_0(\n",
606
+ " query=x,\n",
607
+ " key=x,\n",
608
+ " value=x,\n",
609
+ " training=training,\n",
610
+ " )\n",
611
+ " x = self.dropout_0(x, training=training)\n",
612
+ " x = self.add_0([x, inputs])\n",
613
+ "\n",
614
+ " # Attention block\n",
615
+ " y = self.norm_1(x)\n",
616
+ " y = self.mha_1(\n",
617
+ " query=y,\n",
618
+ " key=enc_output,\n",
619
+ " value=enc_output,\n",
620
+ " training=training,\n",
621
+ " )\n",
622
+ " y = self.dropout_1(y, training=training)\n",
623
+ " y = self.add_1([x, y])\n",
624
+ "\n",
625
+ " # MLP block\n",
626
+ " z = self.norm_2(y)\n",
627
+ " z = self.dense_0(z)\n",
628
+ " z = self.dense_1(z)\n",
629
+ " z = self.dropout_2(z, training=training)\n",
630
+ "\n",
631
+ " return self.add_2([y, z])"
632
+ ]
633
+ },
634
+ {
635
+ "cell_type": "markdown",
636
+ "metadata": {},
637
+ "source": [
638
+ "## Model"
639
+ ]
640
+ },
641
+ {
642
+ "cell_type": "markdown",
643
+ "metadata": {},
644
+ "source": [
645
+ "### Transformer"
646
+ ]
647
+ },
648
+ {
649
+ "cell_type": "code",
650
+ "execution_count": null,
651
+ "metadata": {},
652
+ "outputs": [],
653
+ "source": [
654
+ "class DailyTransformer(Model):\n",
655
+ " def __init__(\n",
656
+ " self,\n",
657
+ " num_encoder_layers,\n",
658
+ " num_decoder_layers,\n",
659
+ " embed_dim,\n",
660
+ " mlp_dim,\n",
661
+ " num_heads,\n",
662
+ " num_outputs,\n",
663
+ " dropout_rate,\n",
664
+ " attention_dropout_rate,\n",
665
+ " **kwargs\n",
666
+ " ):\n",
667
+ " super(DailyTransformer, self).__init__(**kwargs)\n",
668
+ "\n",
669
+ " # Input (normalization of RAW measurements)\n",
670
+ " self.input_norm_enc = Normalization(invert=False)\n",
671
+ " self.input_norm_dec1 = Normalization(invert=False)\n",
672
+ " self.input_norm_dec2 = Normalization(invert=True)\n",
673
+ "\n",
674
+ " # Input\n",
675
+ " self.pos_embs_0 = PositionalEmbedding(embed_dim, dropout_rate)\n",
676
+ " self.pos_embs_1 = PositionalEmbedding(embed_dim, dropout_rate)\n",
677
+ "\n",
678
+ " # Encoder\n",
679
+ " self.enc_layers = [\n",
680
+ " Encoder(embed_dim, mlp_dim, num_heads, dropout_rate, attention_dropout_rate)\n",
681
+ " for _ in range(num_encoder_layers)\n",
682
+ " ]\n",
683
+ " self.norm_0 = LayerNormalization(epsilon=1e-12)\n",
684
+ "\n",
685
+ " # Decoder\n",
686
+ " self.dec_layers = [\n",
687
+ " Decoder(embed_dim, mlp_dim, num_heads, dropout_rate, attention_dropout_rate)\n",
688
+ " for _ in range(num_decoder_layers)\n",
689
+ " ]\n",
690
+ " self.norm_1 = LayerNormalization(epsilon=1e-12)\n",
691
+ "\n",
692
+ " # Output\n",
693
+ " self.final_layer = Dense(\n",
694
+ " units=num_outputs,\n",
695
+ " kernel_initializer=TruncatedNormal(stddev=0.02),\n",
696
+ " )\n",
697
+ "\n",
698
+ " def call(self, inputs, training):\n",
699
+ " inputs, targets = inputs\n",
700
+ "\n",
701
+ " # Encoder input\n",
702
+ " x_e = self.input_norm_enc(inputs)\n",
703
+ " x_e = self.pos_embs_0(x_e, training=training)\n",
704
+ "\n",
705
+ " # Encoder\n",
706
+ " for layer in self.enc_layers:\n",
707
+ " x_e = layer(x_e, training=training)\n",
708
+ " x_e = self.norm_0(x_e)\n",
709
+ "\n",
710
+ " # Decoder input\n",
711
+ " x_d = self.input_norm_dec1(targets)\n",
712
+ " x_d = self.pos_embs_1(x_d, training=training)\n",
713
+ "\n",
714
+ " # Decoder\n",
715
+ " for layer in self.dec_layers:\n",
716
+ " x_d = layer(x_d, x_e, training=training)\n",
717
+ " x_d = self.norm_1(x_d)\n",
718
+ "\n",
719
+ " # Output\n",
720
+ " final_output = self.final_layer(x_d)\n",
721
+ " final_output = self.input_norm_dec2(final_output)\n",
722
+ "\n",
723
+ " return final_output\n",
724
+ "\n",
725
+ " def train_step(self, inputs):\n",
726
+ " inputs, targets = inputs\n",
727
+ " inputs = inputs[:, :-1]\n",
728
+ " targets_inputs = targets[:, :-1]\n",
729
+ " targets_real = targets[:, 1:, :, -1:]\n",
730
+ "\n",
731
+ " with tf.GradientTape() as tape:\n",
732
+ " y_pred = self([inputs, targets_inputs], training=True)\n",
733
+ " loss = self.compiled_loss(targets_real, y_pred, regularization_losses=self.losses)\n",
734
+ "\n",
735
+ " print(y_pred)\n",
736
+ " print(targets_real)\n",
737
+ "\n",
738
+ " # Compute gradients\n",
739
+ " trainable_vars = self.trainable_variables\n",
740
+ " gradients = tape.gradient(loss, trainable_vars)\n",
741
+ "\n",
742
+ " # Update weights\n",
743
+ " self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n",
744
+ "\n",
745
+ " # Update metrics (includes the metric that tracks the loss)\n",
746
+ " self.compiled_metrics.update_state(targets_real[:, -1], y_pred[:, -1])\n",
747
+ "\n",
748
+ " # Return a dict mapping metric names to current value\n",
749
+ " return {m.name: m.result() for m in self.metrics}\n",
750
+ " \n",
751
+ " def test_step(self, inputs):\n",
752
+ " inputs, targets = inputs\n",
753
+ " inputs = inputs[:, :-1]\n",
754
+ " targets_inputs = targets[:, :-1]\n",
755
+ " targets_real = targets[:, 1:, :, -1:]\n",
756
+ "\n",
757
+ " # Compute predictions\n",
758
+ " y_pred = self([inputs, targets_inputs], training=False)\n",
759
+ "\n",
760
+ " # Updates the metrics tracking the loss\n",
761
+ " self.compiled_loss(targets_real, y_pred, regularization_losses=self.losses)\n",
762
+ "\n",
763
+ " # Update the metrics\n",
764
+ " self.compiled_metrics.update_state(targets_real[:, -1], y_pred[:, -1])\n",
765
+ "\n",
766
+ " # Return a dict mapping metric names to current value\n",
767
+ " # Note that it will include the loss (tracked in self.metrics)\n",
768
+ " return {m.name: m.result() for m in self.metrics}"
769
+ ]
770
+ },
771
+ {
772
+ "cell_type": "markdown",
773
+ "metadata": {},
774
+ "source": [
775
+ "### Simulator"
776
+ ]
777
+ },
778
+ {
779
+ "cell_type": "code",
780
+ "execution_count": null,
781
+ "metadata": {},
782
+ "outputs": [],
783
+ "source": [
784
+ "class Simulator(tf.Module):\n",
785
+ " def __init__(self, transformer):\n",
786
+ " self.transformer = transformer\n",
787
+ " self.pi = tf.constant(np.pi)\n",
788
+ "\n",
789
+ " def __call__(self, inputs, horizon_length):\n",
790
+ " inputs, targets = inputs\n",
791
+ " output_array = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)\n",
792
+ "\n",
793
+ " for i in tf.range(horizon_length):\n",
794
+ " tar = targets[:, i:]\n",
795
+ " #print(\"target_old:\", tar[0])\n",
796
+ " \n",
797
+ " # Concatenate history with the predicted future\n",
798
+ " if i > 0:\n",
799
+ " output = tf.transpose(output_array.stack(), perm=[1, 0, 2, 3])\n",
800
+ " if i > tf.shape(inputs)[1]:\n",
801
+ " tar = tf.concat([tar, output[:, (i - tf.shape(inputs)[1]):]], axis=1)\n",
802
+ " else:\n",
803
+ " tar = tf.concat([tar, output], axis=1)\n",
804
+ " #print(\"target_new[\", i, \"]:\", tar[0])\n",
805
+ "\n",
806
+ " #print(\"day sin/cos_OLD:\", tar[0, -1, 0, :-1])\n",
807
+ "\n",
808
+ " day = (tf.atan2(tar[:, -1, :, 0], tar[:, -1, :, 1]) * 183.0) / self.pi\n",
809
+ " day = tf.round(tf.where(day > 0, day, day + 366))\n",
810
+ " \n",
811
+ " day_sin = tf.expand_dims(tf.sin(2.0 * self.pi * (day + 1) / 366.0), axis=-1)\n",
812
+ " day_cos = tf.expand_dims(tf.cos(2.0 * self.pi * (day + 1) / 366.0), axis=-1)\n",
813
+ "\n",
814
+ " #print(\"day: \", day)\n",
815
+ " #print(\"day sin/cos_NEW:\", day_sin[0], day_cos[0])\n",
816
+ "\n",
817
+ " predictions = self.transformer([inputs, tar], training=False)\n",
818
+ " #print(\"predictions: \", predictions[0])\n",
819
+ "\n",
820
+ " if i == 0:\n",
821
+ " zero_predictions = predictions[:, :-1]\n",
822
+ "\n",
823
+ " # concatentate the prediction to the output which is given to the decoder as its input\n",
824
+ " output_array = output_array.write(i, tf.concat([day_sin, day_cos, predictions[:, -1]], axis=-1))\n",
825
+ "\n",
826
+ " output = tf.transpose(output_array.stack(), perm=[1, 0, 2, 3])\n",
827
+ " #print(output.shape)\n",
828
+ "\n",
829
+ " return tf.concat([zero_predictions, output[:, :, :, -1:]], axis=1)"
830
+ ]
831
+ },
832
+ {
833
+ "cell_type": "markdown",
834
+ "metadata": {},
835
+ "source": [
836
+ "## Dataset"
837
+ ]
838
+ },
839
+ {
840
+ "cell_type": "code",
841
+ "execution_count": null,
842
+ "metadata": {},
843
+ "outputs": [],
844
+ "source": [
845
+ "df_X = pd.read_csv(\"./dataset/1984_2022/X_all_daily.csv\")\n",
846
+ "df_y_daily = pd.read_csv(\"./dataset/1984_2022/y_all_daily.csv\")\n",
847
+ "\n",
848
+ "num_of_patches = df_X['Name'].nunique()\n",
849
+ "\n",
850
+ "df_X = df_X.drop(\n",
851
+ " columns=['DateTime', 'Name', 'Latitude', 'Longitude'] +\n",
852
+ " [c for c in df_X.columns if c[:9] == 'WindSpeed'] +\n",
853
+ " [c for c in df_X.columns if c[:12] == 'WindSpeedMin'] +\n",
854
+ " [c for c in df_X.columns if c[:12] == 'WindSpeedMax'] +\n",
855
+ " [c for c in df_X.columns if c[:13] == 'WindDirection']\n",
856
+ ")\n",
857
+ "df_y_daily = df_y_daily.drop(\n",
858
+ " columns=['DateTime', 'Name', 'Latitude', 'Longitude'] +\n",
859
+ " [c for c in df_y_daily.columns if c[:9] == 'WindSpeed'] +\n",
860
+ " [c for c in df_y_daily.columns if c[:12] == 'WindSpeedMin'] +\n",
861
+ " [c for c in df_y_daily.columns if c[:12] == 'WindSpeedMax'] +\n",
862
+ " [c for c in df_y_daily.columns if c[:13] == 'WindDirection']\n",
863
+ ")\n",
864
+ "\n",
865
+ "loc_names = [\n",
866
+ " \"54 MW PV SOLAR POWER PLANT\",\n",
867
+ " \"5MW Solar Power Plant Varroc\",\n",
868
+ " \"Adani Green Energy Tamilnadu Limited\",\n",
869
+ " \"Arete Elena Energy Pvt Ltd\",\n",
870
+ " \"Bitta Solar Power Plant\",\n",
871
+ " \"Charanka Solar Park\",\n",
872
+ " \"Chennai Metropolitan Area\",\n",
873
+ " \"Ctrls Data Center Mumbai\",\n",
874
+ " \"Indira Paryavaran Bhawan\",\n",
875
+ " \"Kurnool Ultra Mega Solar Park\",\n",
876
+ " \"Pavagada Solar Park\",\n",
877
+ " \"Rewa Ultra Mega Solar\",\n",
878
+ " \"Solar Power Plant Chandasar\",\n",
879
+ " \"Solar Power Plant Khera Silajit\",\n",
880
+ " \"Solar power plant Koppal\",\n",
881
+ " \"Target 1\",\n",
882
+ " \"Target 2\",\n",
883
+ " \"Welspun Solar MP project\",\n",
884
+ "]"
885
+ ]
886
+ },
887
+ {
888
+ "cell_type": "code",
889
+ "execution_count": null,
890
+ "metadata": {},
891
+ "outputs": [],
892
+ "source": [
893
+ "print(df_X.head())\n",
894
+ "print(df_y_daily.head())"
895
+ ]
896
+ },
897
+ {
898
+ "cell_type": "code",
899
+ "execution_count": null,
900
+ "metadata": {},
901
+ "outputs": [],
902
+ "source": [
903
+ "def make_dataset(data, sequence_length, sequence_stride, sampling_rate):\n",
904
+ " def make_window(data):\n",
905
+ " dataset = tf.data.Dataset.from_tensor_slices(data)\n",
906
+ " dataset = dataset.window(sequence_length, shift=sequence_stride, stride=sampling_rate, drop_remainder=True)\n",
907
+ " dataset = dataset.flat_map(lambda x: x.batch(sequence_length, drop_remainder=True)) \n",
908
+ " return dataset\n",
909
+ "\n",
910
+ " data = np.array(data, dtype=np.float32)\n",
911
+ " data = np.reshape(data, (-1, num_of_patches, data.shape[-1]))\n",
912
+ "\n",
913
+ " # Split the data\n",
914
+ " # (80%, 10%, 10%)\n",
915
+ " n = data.shape[0]\n",
916
+ " n_train = int(n*0.8)\n",
917
+ " n_val = int(n*0.9)\n",
918
+ " train_data = data[0:n_train]\n",
919
+ " val_data = data[n_train:n_val]\n",
920
+ " test_data = data[n_val:]\n",
921
+ "\n",
922
+ " return (\n",
923
+ " (n_train, make_window(train_data)),\n",
924
+ " (n_val - n_train, make_window(val_data)),\n",
925
+ " make_window(test_data)\n",
926
+ " )\n",
927
+ "\n",
928
+ "def merge_dataset(datasets, batch_size, shuffle):\n",
929
+ " dataset = tf.data.Dataset.zip(datasets)\n",
930
+ " dataset = dataset.prefetch(tf.data.AUTOTUNE)\n",
931
+ "\n",
932
+ " if shuffle:\n",
933
+ " # Shuffle locally at each iteration\n",
934
+ " dataset = dataset.shuffle(buffer_size=1000)\n",
935
+ " dataset = dataset.batch(batch_size)\n",
936
+ " \n",
937
+ " return dataset"
938
+ ]
939
+ },
940
+ {
941
+ "cell_type": "markdown",
942
+ "metadata": {},
943
+ "source": [
944
+ "## Simulation"
945
+ ]
946
+ },
947
+ {
948
+ "cell_type": "code",
949
+ "execution_count": null,
950
+ "metadata": {},
951
+ "outputs": [],
952
+ "source": [
953
+ "horizon = 7\n",
954
+ "window_size = 7\n",
955
+ "batch_size = 32\n",
956
+ "\n",
957
+ "_, _, test_X_ds = make_dataset(df_X, (window_size + horizon), 1, 1)\n",
958
+ "_, _, test_y_daily_ds = make_dataset(df_y_daily, (window_size + horizon), 1, 1)\n",
959
+ "\n",
960
+ "test_ds = merge_dataset(\n",
961
+ " (test_X_ds, test_y_daily_ds),\n",
962
+ " batch_size,\n",
963
+ " shuffle=False,\n",
964
+ ")\n",
965
+ "\n",
966
+ "daily_model = DailyTransformer(\n",
967
+ " attention_dropout_rate=0.25,\n",
968
+ " dropout_rate=0.15,\n",
969
+ " embed_dim=64,\n",
970
+ " mlp_dim=256,\n",
971
+ " num_decoder_layers=6,\n",
972
+ " num_encoder_layers=3,\n",
973
+ " num_heads=6,\n",
974
+ " num_outputs=1,\n",
975
+ ")\n",
976
+ "daily_model.build([(None, window_size, num_of_patches, 302), (None, window_size, num_of_patches, 3)])\n",
977
+ "daily_model.load_weights(\"./models/model-best.h5\")\n",
978
+ "simulator = Simulator(daily_model)\n",
979
+ "\n",
980
+ "print(daily_model.input_norm_enc.variables)\n",
981
+ "print(daily_model.input_norm_dec1.variables)\n",
982
+ "print(daily_model.input_norm_dec2.variables)"
983
+ ]
984
+ },
985
+ {
986
+ "cell_type": "code",
987
+ "execution_count": null,
988
+ "metadata": {},
989
+ "outputs": [],
990
+ "source": [
991
+ "patch_similarity_plot(daily_model.pos_embs_0.spatial_position[0, 0])\n",
992
+ "patch_similarity_plot(daily_model.pos_embs_1.spatial_position[0, 0])\n",
993
+ "\n",
994
+ "timestep_similarity_plot(daily_model.pos_embs_0.temporal_position[0, :, 0])\n",
995
+ "timestep_similarity_plot(daily_model.pos_embs_1.temporal_position[0, :, 0])"
996
+ ]
997
+ },
998
+ {
999
+ "cell_type": "markdown",
1000
+ "metadata": {},
1001
+ "source": [
1002
+ "### Results"
1003
+ ]
1004
+ },
1005
+ {
1006
+ "cell_type": "code",
1007
+ "execution_count": null,
1008
+ "metadata": {},
1009
+ "outputs": [],
1010
+ "source": [
1011
+ "metrics = [MeanSquaredError(), RootMeanSquaredError(), MeanAbsoluteError(), MeanAbsolutePercentageError(), RSquare()]\n",
1012
+ "\n",
1013
+ "# Location 1 = 15 (64.67 % na 4 dni), (80.6 % na 1 den)\n",
1014
+ "# Location 2 = 16 (69.8 % na 4 dni), (83.67 % na 1 den)\n",
1015
+ "\n",
1016
+ "# Chennai = 6 (69.8 % na 4 dni), (83.67 % na 1 den)\n",
1017
+ "# Mumbai = 7 (69.8 % na 4 dni), (83.67 % na 1 den)\n",
1018
+ "\n",
1019
+ "for loc in range(num_of_patches):\n",
1020
+ " print(\"Location: \", loc_names[loc])\n",
1021
+ " print(\"-----------------------------------------------------\")\n",
1022
+ " for inputs in test_ds:\n",
1023
+ " inputs, targets = inputs\n",
1024
+ " inputs = inputs[:, :-horizon]\n",
1025
+ " targets_inputs = targets[:, :-horizon]\n",
1026
+ " targets_real = targets[:, 1:, loc, -1:]\n",
1027
+ "\n",
1028
+ " #y_pred = daily_model([inputs, targets_inputs], training=False)\n",
1029
+ " y_pred = simulator([inputs, targets_inputs], horizon_length=horizon)\n",
1030
+ "\n",
1031
+ " # Update the metrics\n",
1032
+ " for m in metrics:\n",
1033
+ " m.update_state(targets_real, y_pred[:, :, loc, -1:])\n",
1034
+ "\n",
1035
+ " # visualize the last results\n",
1036
+ " plot_prediction(targets, y_pred)\n",
1037
+ "\n",
1038
+ " print({m.name: m.result() for m in metrics}, \"\\n\")\n",
1039
+ " for m in metrics:\n",
1040
+ " m.reset_states()"
1041
+ ]
1042
+ }
1043
+ ],
1044
+ "metadata": {
1045
+ "kernelspec": {
1046
+ "display_name": "Python 3.9.10 ('base')",
1047
+ "language": "python",
1048
+ "name": "python3"
1049
+ },
1050
+ "language_info": {
1051
+ "codemirror_mode": {
1052
+ "name": "ipython",
1053
+ "version": 3
1054
+ },
1055
+ "file_extension": ".py",
1056
+ "mimetype": "text/x-python",
1057
+ "name": "python",
1058
+ "nbconvert_exporter": "python",
1059
+ "pygments_lexer": "ipython3",
1060
+ "version": "3.9.10"
1061
+ },
1062
+ "orig_nbformat": 4,
1063
+ "vscode": {
1064
+ "interpreter": {
1065
+ "hash": "9185113d2128201d66faecd4f34fb34e89a635073a034991399523e584519355"
1066
+ }
1067
+ }
1068
+ },
1069
+ "nbformat": 4,
1070
+ "nbformat_minor": 2
1071
+ }
Training.ipynb ADDED
@@ -0,0 +1,1433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "AMmP-w4YRzBU"
7
+ },
8
+ "source": [
9
+ "# Solar Transformer"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "colab": {
17
+ "base_uri": "https://localhost:8080/"
18
+ },
19
+ "id": "7Z0GTvvpRzBx",
20
+ "outputId": "33adb728-76ee-48e4-b388-89717bca8482"
21
+ },
22
+ "outputs": [],
23
+ "source": [
24
+ "!pip install wandb tensorflow_probability tensorflow_addons"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {
31
+ "id": "EMUn36ELRzCJ"
32
+ },
33
+ "outputs": [],
34
+ "source": [
35
+ "from tensorflow.keras.layers import Add, Dense, Dropout, Layer, LayerNormalization, MultiHeadAttention\n",
36
+ "from tensorflow.keras.models import Model\n",
37
+ "from tensorflow.keras.initializers import TruncatedNormal\n",
38
+ "from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler, Callback\n",
39
+ "from tensorflow.keras.optimizers import Adam\n",
40
+ "from tensorflow.keras.metrics import MeanSquaredError, RootMeanSquaredError, MeanAbsoluteError\n",
41
+ "from tensorflow_addons.metrics import RSquare\n",
42
+ "from wandb.keras import WandbCallback\n",
43
+ "\n",
44
+ "import math\n",
45
+ "import wandb\n",
46
+ "import pandas as pd\n",
47
+ "import tensorflow as tf\n",
48
+ "import tensorflow_probability as tfp\n",
49
+ "import tensorflow_addons as tfa\n",
50
+ "import numpy as np\n",
51
+ "import matplotlib.pyplot as plt\n",
52
+ "import seaborn as sns"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": null,
58
+ "metadata": {
59
+ "colab": {
60
+ "base_uri": "https://localhost:8080/"
61
+ },
62
+ "id": "VWPvxWIdKfRM",
63
+ "outputId": "3e04ac70-9036-4b16-daca-28b2ef0707cd"
64
+ },
65
+ "outputs": [],
66
+ "source": [
67
+ "from google.colab import drive\n",
68
+ "drive.mount('/content/drive')"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "markdown",
73
+ "metadata": {
74
+ "id": "XgDBs9_3l4uD"
75
+ },
76
+ "source": [
77
+ "## Plotting"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "metadata": {
84
+ "id": "TDdF5YM4l4Au"
85
+ },
86
+ "outputs": [],
87
+ "source": [
88
+ "def plot_4d(matrix):\n",
89
+ " fig = plt.figure(figsize=(10, 20), dpi=300)\n",
90
+ " plt.title(\"Attention heatmap\")\n",
91
+ "\n",
92
+ " # create grid\n",
93
+ " x = np.arange(0, matrix.shape[0], 1, dtype=np.int32) # timesteps\n",
94
+ " y = np.arange(0, matrix.shape[1], 1, dtype=np.int32) # patches\n",
95
+ " z = np.arange(0, matrix.shape[2], 1, dtype=np.int32) # timesteps\n",
96
+ " X, Y, Z = np.meshgrid(x, y, z)\n",
97
+ "\n",
98
+ " X = X.transpose([1, 0, 2])\n",
99
+ " Y = Y.transpose([1, 0, 2])\n",
100
+ " Z = Z.transpose([1, 0, 2])\n",
101
+ "\n",
102
+ " for I in range(matrix.shape[3]):\n",
103
+ " # Plot\n",
104
+ " ax = plt.subplot(5, 5, I+1, projection=\"3d\")\n",
105
+ " ax.scatter3D(X, Y, Z, c=matrix[:, :, :, I], marker='s', s=99, cmap='jet')\n",
106
+ " ax.set_title(\n",
107
+ " f\"{I}-th patch\"\n",
108
+ " )\n",
109
+ " ax.set_xlabel(\"Timestep\")\n",
110
+ " ax.set_ylabel(\"Patch\")\n",
111
+ " ax.set_zlabel(\"Timestep\")\n",
112
+ "\n",
113
+ " plt.show()"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": null,
119
+ "metadata": {
120
+ "id": "hdxToaPDOXPg"
121
+ },
122
+ "outputs": [],
123
+ "source": [
124
+ "def patch_similarity_plot(pos):\n",
125
+ " similarity_scores = np.dot(\n",
126
+ " pos, np.transpose(pos)\n",
127
+ " ) / (\n",
128
+ " np.linalg.norm(pos, axis=-1)\n",
129
+ " * np.linalg.norm(pos, axis=-1)\n",
130
+ " )\n",
131
+ "\n",
132
+ " plt.figure(figsize=(7, 7), dpi=300)\n",
133
+ " ax = sns.heatmap(similarity_scores, center=0)\n",
134
+ " ax.set_title(\"Spatial Positional Embedding\")\n",
135
+ " ax.set_xlabel(\"Patch\")\n",
136
+ " ax.set_ylabel(\"Patch\")\n",
137
+ " plt.show()\n",
138
+ "\n",
139
+ "def timestep_similarity_plot(pos):\n",
140
+ " similarity_scores = np.dot(\n",
141
+ " pos, np.transpose(pos)\n",
142
+ " ) / (\n",
143
+ " np.linalg.norm(pos, axis=-1)\n",
144
+ " * np.linalg.norm(pos, axis=-1)\n",
145
+ " )\n",
146
+ "\n",
147
+ " plt.figure(figsize=(7, 7), dpi=300)\n",
148
+ " ax = sns.heatmap(similarity_scores, center=0)\n",
149
+ " ax.set_title(\"Temporal Positional Embedding\")\n",
150
+ " ax.set_xlabel(\"Timestep\")\n",
151
+ " ax.set_ylabel(\"Timestep\")\n",
152
+ " plt.show()"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": null,
158
+ "metadata": {
159
+ "id": "Ky-VvsYT2aSz"
160
+ },
161
+ "outputs": [],
162
+ "source": [
163
+ "def plot_prediction(inputs, model, max_subplots=3):\n",
164
+ " inputs, targets = inputs\n",
165
+ " inputs = inputs[:, :-1]\n",
166
+ "\n",
167
+ " plt.figure(figsize=(12, 15))\n",
168
+ " max_n = min(max_subplots, len(targets))\n",
169
+ " for n in range(max_n):\n",
170
+ " # input\n",
171
+ " plt.subplot(max_n, 1, n+1)\n",
172
+ " plt.ylabel('Solar irradiance [kW-hr/m^2/day]')\n",
173
+ " plt.plot(np.arange(targets.shape[1]-1), targets[n, :-1, 0, -1], label='Inputs', marker='.', zorder=-10)\n",
174
+ "\n",
175
+ " # real\n",
176
+ " plt.scatter(np.arange(1, targets.shape[1]), targets[n, 1:, 0, -1], edgecolors='k', label='Targets', c='#2cb01d', s=64)\n",
177
+ " \n",
178
+ " # predicted\n",
179
+ " predictions = model([inputs, targets[:, :-1]], training=False)\n",
180
+ " plt.scatter(np.arange(1, targets.shape[1]), predictions[n, :, 0, -1], marker='X', edgecolors='k', label='Predictions', c='#fe7e0f', s=64)\n",
181
+ "\n",
182
+ " if n == 0:\n",
183
+ " plt.legend()\n",
184
+ "\n",
185
+ " plt.xlabel('Time [day]')\n",
186
+ " plt.show()"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "markdown",
191
+ "metadata": {
192
+ "id": "I4i5bIzkRzC2"
193
+ },
194
+ "source": [
195
+ "## Init logger"
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "code",
200
+ "execution_count": null,
201
+ "metadata": {
202
+ "colab": {
203
+ "base_uri": "https://localhost:8080/",
204
+ "height": 106
205
+ },
206
+ "id": "ABI9_YirRzDA",
207
+ "outputId": "1c9c6c2d-a8dd-4375-fe49-62e59da6969a"
208
+ },
209
+ "outputs": [],
210
+ "source": [
211
+ "wandb.login()\n",
212
+ "\n",
213
+ "sweep_config = {\n",
214
+ " 'method': 'grid',\n",
215
+ " 'metric': {\n",
216
+ " 'goal': 'minimize',\n",
217
+ " 'name': 'val_mean_squared_error'\n",
218
+ " },\n",
219
+ " 'parameters': {\n",
220
+ " 'epochs': {\n",
221
+ " 'value': 1000\n",
222
+ " },\n",
223
+ " 'num_encoder_layers': {\n",
224
+ " 'value': 3\n",
225
+ " },\n",
226
+ " 'num_decoder_layers': {\n",
227
+ " 'value': 6\n",
228
+ " },\n",
229
+ " 'embed_layer_size': {\n",
230
+ " 'value': 64\n",
231
+ " },\n",
232
+ " 'fc_layer_size': {\n",
233
+ " 'value': 256\n",
234
+ " },\n",
235
+ " 'num_heads': {\n",
236
+ " 'value': 8\n",
237
+ " },\n",
238
+ " 'dropout': {\n",
239
+ " 'value': 0.15\n",
240
+ " },\n",
241
+ " 'attention_dropout': {\n",
242
+ " 'value': 0.25\n",
243
+ " },\n",
244
+ " 'optimizer': {\n",
245
+ " 'value': 'adamw'\n",
246
+ " },\n",
247
+ " 'global_clipnorm': {\n",
248
+ " 'value': 2.0\n",
249
+ " },\n",
250
+ " 'learning_rate': {\n",
251
+ " 'value': 0.005\n",
252
+ " },\n",
253
+ " 'weight_decay': {\n",
254
+ " 'value': 0.00001\n",
255
+ " },\n",
256
+ " 'warmup_steps': {\n",
257
+ " 'value': 70\n",
258
+ " },\n",
259
+ " 'window_size': {\n",
260
+ " 'value': 7 # every 7 days\n",
261
+ " },\n",
262
+ " 'batch_size': {\n",
263
+ " 'value': 32\n",
264
+ " },\n",
265
+ " }\n",
266
+ "}\n",
267
+ "\n",
268
+ "sweep_id = wandb.sweep(sweep_config, project=\"solar-transformer\")\n",
269
+ "!export WANDB_AGENT_MAX_INITIAL_FAILURES=1024"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "markdown",
274
+ "metadata": {
275
+ "id": "oQyRcTjTRzEE"
276
+ },
277
+ "source": [
278
+ "## Layer"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": null,
284
+ "metadata": {
285
+ "id": "pbmmbGbfRzEO"
286
+ },
287
+ "outputs": [],
288
+ "source": [
289
+ "class PositionalEmbedding(Layer):\n",
290
+ " def __init__(self, units, dropout_rate, **kwargs):\n",
291
+ " super(PositionalEmbedding, self).__init__(**kwargs)\n",
292
+ "\n",
293
+ " self.units = units\n",
294
+ "\n",
295
+ " self.projection = Dense(units, kernel_initializer=TruncatedNormal(stddev=0.02))\n",
296
+ " self.dropout = Dropout(rate=dropout_rate)\n",
297
+ "\n",
298
+ " def build(self, input_shape):\n",
299
+ " super(PositionalEmbedding, self).build(input_shape)\n",
300
+ "\n",
301
+ " print(\"pos_embbeding: \", input_shape)\n",
302
+ " self.temporal_position = self.add_weight(\n",
303
+ " name=\"temporal_position\",\n",
304
+ " shape=(1, input_shape[1], 1, self.units),\n",
305
+ " initializer=TruncatedNormal(stddev=0.02),\n",
306
+ " trainable=True,\n",
307
+ " )\n",
308
+ " self.spatial_position = self.add_weight(\n",
309
+ " name=\"spatial_position\",\n",
310
+ " shape=(1, 1, input_shape[2], self.units),\n",
311
+ " initializer=TruncatedNormal(stddev=0.02),\n",
312
+ " trainable=True,\n",
313
+ " )\n",
314
+ "\n",
315
+ " def call(self, inputs, training):\n",
316
+ " x = self.projection(inputs)\n",
317
+ " x += self.temporal_position\n",
318
+ " x += self.spatial_position\n",
319
+ "\n",
320
+ " return self.dropout(x, training=training)"
321
+ ]
322
+ },
323
+ {
324
+ "cell_type": "code",
325
+ "execution_count": null,
326
+ "metadata": {
327
+ "id": "uswJFdtUVk8Q"
328
+ },
329
+ "outputs": [],
330
+ "source": [
331
+ "class Normalization(tf.keras.layers.experimental.preprocessing.PreprocessingLayer):\n",
332
+ " \"\"\"A preprocessing layer which normalizes continuous features.\n",
333
+ " This layer will shift and scale inputs into a distribution centered around\n",
334
+ " 0 with standard deviation 1. It accomplishes this by precomputing the mean\n",
335
+ " and variance of the data, and calling `(input - mean) / sqrt(var)` at\n",
336
+ " runtime.\n",
337
+ " The mean and variance values for the layer must be either supplied on\n",
338
+ " construction or learned via `adapt()`. `adapt()` will compute the mean and\n",
339
+ " variance of the data and store them as the layer's weights. `adapt()` should\n",
340
+ " be called before `fit()`, `evaluate()`, or `predict()`.\n",
341
+ " For an overview and full list of preprocessing layers, see the preprocessing\n",
342
+ " [guide](https://www.tensorflow.org/guide/keras/preprocessing_layers).\n",
343
+ " Args:\n",
344
+ " axis: Integer, tuple of integers, or None. The axis or axes that should\n",
345
+ " have a separate mean and variance for each index in the shape. For\n",
346
+ " example, if shape is `(None, 5)` and `axis=1`, the layer will track 5\n",
347
+ " separate mean and variance values for the last axis. If `axis` is set\n",
348
+ " to `None`, the layer will normalize all elements in the input by a\n",
349
+ " scalar mean and variance. Defaults to -1, where the last axis of the\n",
350
+ " input is assumed to be a feature dimension and is normalized per\n",
351
+ " index. Note that in the specific case of batched scalar inputs where\n",
352
+ " the only axis is the batch axis, the default will normalize each index\n",
353
+ " in the batch separately. In this case, consider passing `axis=None`.\n",
354
+ " mean: The mean value(s) to use during normalization. The passed value(s)\n",
355
+ " will be broadcast to the shape of the kept axes above; if the value(s)\n",
356
+ " cannot be broadcast, an error will be raised when this layer's\n",
357
+ " `build()` method is called.\n",
358
+ " variance: The variance value(s) to use during normalization. The passed\n",
359
+ " value(s) will be broadcast to the shape of the kept axes above; if the\n",
360
+ " value(s) cannot be broadcast, an error will be raised when this\n",
361
+ " layer's `build()` method is called.\n",
362
+ " invert: If True, this layer will apply the inverse transformation\n",
363
+ " to its inputs: it would turn a normalized input back into its\n",
364
+ " original form.\n",
365
+ " Examples:\n",
366
+ " Calculate a global mean and variance by analyzing the dataset in `adapt()`.\n",
367
+ " >>> adapt_data = np.array([1., 2., 3., 4., 5.], dtype='float32')\n",
368
+ " >>> input_data = np.array([1., 2., 3.], dtype='float32')\n",
369
+ " >>> layer = tf.keras.layers.Normalization(axis=None)\n",
370
+ " >>> layer.adapt(adapt_data)\n",
371
+ " >>> layer(input_data)\n",
372
+ " <tf.Tensor: shape=(3,), dtype=float32, numpy=\n",
373
+ " array([-1.4142135, -0.70710677, 0.], dtype=float32)>\n",
374
+ " Calculate a mean and variance for each index on the last axis.\n",
375
+ " >>> adapt_data = np.array([[0., 7., 4.],\n",
376
+ " ... [2., 9., 6.],\n",
377
+ " ... [0., 7., 4.],\n",
378
+ " ... [2., 9., 6.]], dtype='float32')\n",
379
+ " >>> input_data = np.array([[0., 7., 4.]], dtype='float32')\n",
380
+ " >>> layer = tf.keras.layers.Normalization(axis=-1)\n",
381
+ " >>> layer.adapt(adapt_data)\n",
382
+ " >>> layer(input_data)\n",
383
+ " <tf.Tensor: shape=(1, 3), dtype=float32, numpy=\n",
384
+ " array([-1., -1., -1.], dtype=float32)>\n",
385
+ " Pass the mean and variance directly.\n",
386
+ " >>> input_data = np.array([[1.], [2.], [3.]], dtype='float32')\n",
387
+ " >>> layer = tf.keras.layers.Normalization(mean=3., variance=2.)\n",
388
+ " >>> layer(input_data)\n",
389
+ " <tf.Tensor: shape=(3, 1), dtype=float32, numpy=\n",
390
+ " array([[-1.4142135 ],\n",
391
+ " [-0.70710677],\n",
392
+ " [ 0. ]], dtype=float32)>\n",
393
+ " Use the layer to de-normalize inputs (after adapting the layer).\n",
394
+ " >>> adapt_data = np.array([[0., 7., 4.],\n",
395
+ " ... [2., 9., 6.],\n",
396
+ " ... [0., 7., 4.],\n",
397
+ " ... [2., 9., 6.]], dtype='float32')\n",
398
+ " >>> input_data = np.array([[1., 2., 3.]], dtype='float32')\n",
399
+ " >>> layer = tf.keras.layers.Normalization(axis=-1, invert=True)\n",
400
+ " >>> layer.adapt(adapt_data)\n",
401
+ " >>> layer(input_data)\n",
402
+ " <tf.Tensor: shape=(1, 3), dtype=float32, numpy=\n",
403
+ " array([2., 10., 8.], dtype=float32)>\n",
404
+ " \"\"\"\n",
405
+ "\n",
406
+ " def __init__(\n",
407
+ " self, axis=-1, mean=None, variance=None, invert=False, **kwargs\n",
408
+ " ):\n",
409
+ " super().__init__(**kwargs)\n",
410
+ "\n",
411
+ " # Standardize `axis` to a tuple.\n",
412
+ " if axis is None:\n",
413
+ " axis = ()\n",
414
+ " elif isinstance(axis, int):\n",
415
+ " axis = (axis,)\n",
416
+ " else:\n",
417
+ " axis = tuple(axis)\n",
418
+ " self.axis = axis\n",
419
+ "\n",
420
+ " # Set `mean` and `variance` if passed.\n",
421
+ " if isinstance(mean, tf.Variable):\n",
422
+ " raise ValueError(\n",
423
+ " \"Normalization does not support passing a Variable \"\n",
424
+ " \"for the `mean` init arg.\"\n",
425
+ " )\n",
426
+ " if isinstance(variance, tf.Variable):\n",
427
+ " raise ValueError(\n",
428
+ " \"Normalization does not support passing a Variable \"\n",
429
+ " \"for the `variance` init arg.\"\n",
430
+ " )\n",
431
+ " if (mean is not None) != (variance is not None):\n",
432
+ " raise ValueError(\n",
433
+ " \"When setting values directly, both `mean` and `variance` \"\n",
434
+ " \"must be set. Got mean: {} and variance: {}\".format(\n",
435
+ " mean, variance\n",
436
+ " )\n",
437
+ " )\n",
438
+ " self.input_mean = mean\n",
439
+ " self.input_variance = variance\n",
440
+ " self.invert = invert\n",
441
+ "\n",
442
+ " def build(self, input_shape):\n",
443
+ " super().build(input_shape)\n",
444
+ "\n",
445
+ " if isinstance(input_shape, (list, tuple)) and all(\n",
446
+ " isinstance(shape, tf.TensorShape) for shape in input_shape\n",
447
+ " ):\n",
448
+ " raise ValueError(\n",
449
+ " \"Normalization only accepts a single input. If you are \"\n",
450
+ " \"passing a python list or tuple as a single input, \"\n",
451
+ " \"please convert to a numpy array or `tf.Tensor`.\"\n",
452
+ " )\n",
453
+ "\n",
454
+ " input_shape = tf.TensorShape(input_shape).as_list()\n",
455
+ " ndim = len(input_shape)\n",
456
+ "\n",
457
+ " if any(a < -ndim or a >= ndim for a in self.axis):\n",
458
+ " raise ValueError(\n",
459
+ " \"All `axis` values must be in the range [-ndim, ndim). \"\n",
460
+ " \"Found ndim: `{}`, axis: {}\".format(ndim, self.axis)\n",
461
+ " )\n",
462
+ "\n",
463
+ " # Axes to be kept, replacing negative values with positive equivalents.\n",
464
+ " # Sorted to avoid transposing axes.\n",
465
+ " self._keep_axis = sorted([d if d >= 0 else d + ndim for d in self.axis])\n",
466
+ " # All axes to be kept should have known shape.\n",
467
+ " for d in self._keep_axis:\n",
468
+ " if input_shape[d] is None:\n",
469
+ " raise ValueError(\n",
470
+ " \"All `axis` values to be kept must have known shape. \"\n",
471
+ " \"Got axis: {}, \"\n",
472
+ " \"input shape: {}, with unknown axis at index: {}\".format(\n",
473
+ " self.axis, input_shape, d\n",
474
+ " )\n",
475
+ " )\n",
476
+ " # Axes to be reduced.\n",
477
+ " self._reduce_axis = [d for d in range(ndim) if d not in self._keep_axis]\n",
478
+ " # 1 if an axis should be reduced, 0 otherwise.\n",
479
+ " self._reduce_axis_mask = [\n",
480
+ " 0 if d in self._keep_axis else 1 for d in range(ndim)\n",
481
+ " ]\n",
482
+ " # Broadcast any reduced axes.\n",
483
+ " self._broadcast_shape = [\n",
484
+ " input_shape[d] if d in self._keep_axis else 1 for d in range(ndim)\n",
485
+ " ]\n",
486
+ " mean_and_var_shape = tuple(input_shape[d] for d in self._keep_axis)\n",
487
+ "\n",
488
+ " if self.input_mean is None:\n",
489
+ " self.adapt_mean = self.add_weight(\n",
490
+ " name=\"mean\",\n",
491
+ " shape=mean_and_var_shape,\n",
492
+ " dtype=self.compute_dtype,\n",
493
+ " initializer=\"zeros\",\n",
494
+ " trainable=False,\n",
495
+ " )\n",
496
+ " self.adapt_variance = self.add_weight(\n",
497
+ " name=\"variance\",\n",
498
+ " shape=mean_and_var_shape,\n",
499
+ " dtype=self.compute_dtype,\n",
500
+ " initializer=\"ones\",\n",
501
+ " trainable=False,\n",
502
+ " )\n",
503
+ " self.count = self.add_weight(\n",
504
+ " name=\"count\",\n",
505
+ " shape=(),\n",
506
+ " dtype=tf.int64,\n",
507
+ " initializer=\"zeros\",\n",
508
+ " trainable=False,\n",
509
+ " )\n",
510
+ " self.finalize_state()\n",
511
+ " else:\n",
512
+ " # In the no adapt case, make constant tensors for mean and variance\n",
513
+ " # with proper broadcast shape for use during call.\n",
514
+ " mean = self.input_mean * np.ones(mean_and_var_shape)\n",
515
+ " variance = self.input_variance * np.ones(mean_and_var_shape)\n",
516
+ " mean = tf.reshape(mean, self._broadcast_shape)\n",
517
+ " variance = tf.reshape(variance, self._broadcast_shape)\n",
518
+ " self.mean = tf.cast(mean, self.compute_dtype)\n",
519
+ " self.variance = tf.cast(variance, self.compute_dtype)\n",
520
+ "\n",
521
+ " # We override this method solely to generate a docstring.\n",
522
+ " def adapt(self, data, batch_size=None, steps=None):\n",
523
+ " \"\"\"Computes the mean and variance of values in a dataset.\n",
524
+ " Calling `adapt()` on a `Normalization` layer is an alternative to\n",
525
+ " passing in `mean` and `variance` arguments during layer construction. A\n",
526
+ " `Normalization` layer should always either be adapted over a dataset or\n",
527
+ " passed `mean` and `variance`.\n",
528
+ " During `adapt()`, the layer will compute a `mean` and `variance`\n",
529
+ " separately for each position in each axis specified by the `axis`\n",
530
+ " argument. To calculate a single `mean` and `variance` over the input\n",
531
+ " data, simply pass `axis=None`.\n",
532
+ " In order to make `Normalization` efficient in any distribution context,\n",
533
+ " the computed mean and variance are kept static with respect to any\n",
534
+ " compiled `tf.Graph`s that call the layer. As a consequence, if the layer\n",
535
+ " is adapted a second time, any models using the layer should be\n",
536
+ " re-compiled. For more information see\n",
537
+ " `tf.keras.layers.experimental.preprocessing.PreprocessingLayer.adapt`.\n",
538
+ " `adapt()` is meant only as a single machine utility to compute layer\n",
539
+ " state. To analyze a dataset that cannot fit on a single machine, see\n",
540
+ " [Tensorflow Transform](\n",
541
+ " https://www.tensorflow.org/tfx/transform/get_started)\n",
542
+ " for a multi-machine, map-reduce solution.\n",
543
+ " Arguments:\n",
544
+ " data: The data to train on. It can be passed either as a\n",
545
+ " `tf.data.Dataset`, or as a numpy array.\n",
546
+ " batch_size: Integer or `None`.\n",
547
+ " Number of samples per state update.\n",
548
+ " If unspecified, `batch_size` will default to 32.\n",
549
+ " Do not specify the `batch_size` if your data is in the\n",
550
+ " form of datasets, generators, or `keras.utils.Sequence` instances\n",
551
+ " (since they generate batches).\n",
552
+ " steps: Integer or `None`.\n",
553
+ " Total number of steps (batches of samples)\n",
554
+ " When training with input tensors such as\n",
555
+ " TensorFlow data tensors, the default `None` is equal to\n",
556
+ " the number of samples in your dataset divided by\n",
557
+ " the batch size, or 1 if that cannot be determined. If x is a\n",
558
+ " `tf.data` dataset, and 'steps' is None, the epoch will run until\n",
559
+ " the input dataset is exhausted. When passing an infinitely\n",
560
+ " repeating dataset, you must specify the `steps` argument. This\n",
561
+ " argument is not supported with array inputs.\n",
562
+ " \"\"\"\n",
563
+ " super().adapt(data, batch_size=batch_size, steps=steps)\n",
564
+ "\n",
565
+ " def update_state(self, data):\n",
566
+ " if self.input_mean is not None:\n",
567
+ " raise ValueError(\n",
568
+ " \"Cannot `adapt` a Normalization layer that is initialized with \"\n",
569
+ " \"static `mean` and `variance`, \"\n",
570
+ " \"you passed mean {} and variance {}.\".format(\n",
571
+ " self.input_mean, self.input_variance\n",
572
+ " )\n",
573
+ " )\n",
574
+ "\n",
575
+ " if not self.built:\n",
576
+ " raise RuntimeError(\"`build` must be called before `update_state`.\")\n",
577
+ "\n",
578
+ " data = self._standardize_inputs(data)\n",
579
+ " data = tf.cast(data, self.adapt_mean.dtype)\n",
580
+ " batch_mean, batch_variance = tf.nn.moments(data, axes=self._reduce_axis)\n",
581
+ " batch_shape = tf.shape(data, out_type=self.count.dtype)\n",
582
+ " if self._reduce_axis:\n",
583
+ " batch_reduce_shape = tf.gather(batch_shape, self._reduce_axis)\n",
584
+ " batch_count = tf.reduce_prod(batch_reduce_shape)\n",
585
+ " else:\n",
586
+ " batch_count = 1\n",
587
+ "\n",
588
+ " total_count = batch_count + self.count\n",
589
+ " batch_weight = tf.cast(batch_count, dtype=self.compute_dtype) / tf.cast(\n",
590
+ " total_count, dtype=self.compute_dtype\n",
591
+ " )\n",
592
+ " existing_weight = 1.0 - batch_weight\n",
593
+ "\n",
594
+ " total_mean = (\n",
595
+ " self.adapt_mean * existing_weight + batch_mean * batch_weight\n",
596
+ " )\n",
597
+ " # The variance is computed using the lack-of-fit sum of squares\n",
598
+ " # formula (see\n",
599
+ " # https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).\n",
600
+ " total_variance = (\n",
601
+ " self.adapt_variance + (self.adapt_mean - total_mean) ** 2\n",
602
+ " ) * existing_weight + (\n",
603
+ " batch_variance + (batch_mean - total_mean) ** 2\n",
604
+ " ) * batch_weight\n",
605
+ " self.adapt_mean.assign(total_mean)\n",
606
+ " self.adapt_variance.assign(total_variance)\n",
607
+ " self.count.assign(total_count)\n",
608
+ "\n",
609
+ " def reset_state(self):\n",
610
+ " if self.input_mean is not None or not self.built:\n",
611
+ " return\n",
612
+ "\n",
613
+ " self.adapt_mean.assign(tf.zeros_like(self.adapt_mean))\n",
614
+ " self.adapt_variance.assign(tf.ones_like(self.adapt_variance))\n",
615
+ " self.count.assign(tf.zeros_like(self.count))\n",
616
+ "\n",
617
+ " def finalize_state(self):\n",
618
+ " if self.input_mean is not None or not self.built:\n",
619
+ " return\n",
620
+ "\n",
621
+ " # In the adapt case, we make constant tensors for mean and variance with\n",
622
+ " # proper broadcast shape and dtype each time `finalize_state` is called.\n",
623
+ " self.mean = tf.reshape(self.adapt_mean, self._broadcast_shape)\n",
624
+ " self.mean = tf.cast(self.mean, self.compute_dtype)\n",
625
+ " self.variance = tf.reshape(self.adapt_variance, self._broadcast_shape)\n",
626
+ " self.variance = tf.cast(self.variance, self.compute_dtype)\n",
627
+ "\n",
628
+ " def call(self, inputs):\n",
629
+ " inputs = self._standardize_inputs(inputs)\n",
630
+ " # The base layer automatically casts floating-point inputs, but we\n",
631
+ " # explicitly cast here to also allow integer inputs to be passed\n",
632
+ " inputs = tf.cast(inputs, self.compute_dtype)\n",
633
+ " if self.invert:\n",
634
+ " return (inputs + self.mean) * tf.maximum(\n",
635
+ " tf.sqrt(self.variance), tf.keras.backend.epsilon()\n",
636
+ " )\n",
637
+ " else:\n",
638
+ " return (inputs - self.mean) / tf.maximum(\n",
639
+ " tf.sqrt(self.variance), tf.keras.backend.epsilon()\n",
640
+ " )\n",
641
+ "\n",
642
+ " def compute_output_shape(self, input_shape):\n",
643
+ " return input_shape\n",
644
+ "\n",
645
+ " def compute_output_signature(self, input_spec):\n",
646
+ " return input_spec\n",
647
+ "\n",
648
+ " def get_config(self):\n",
649
+ " config = super().get_config()\n",
650
+ " config.update(\n",
651
+ " {\n",
652
+ " \"axis\": self.axis,\n",
653
+ " \"mean\": tf.keras.layers.experimental.preprocessing.preprocessing_utils.utils.listify_tensors(self.input_mean),\n",
654
+ " \"variance\": tf.keras.layers.experimental.preprocessing.preprocessing_utils.utils.listify_tensors(self.input_variance),\n",
655
+ " }\n",
656
+ " )\n",
657
+ " return config\n",
658
+ "\n",
659
+ " def _standardize_inputs(self, inputs):\n",
660
+ " inputs = tf.convert_to_tensor(inputs)\n",
661
+ " if inputs.dtype != self.compute_dtype:\n",
662
+ " inputs = tf.cast(inputs, self.compute_dtype)\n",
663
+ " return inputs"
664
+ ]
665
+ },
666
+ {
667
+ "cell_type": "code",
668
+ "execution_count": null,
669
+ "metadata": {
670
+ "id": "rvpW2SbnRzEc"
671
+ },
672
+ "outputs": [],
673
+ "source": [
674
+ "class Encoder(Layer):\n",
675
+ " def __init__(\n",
676
+ " self, embed_dim, mlp_dim, num_heads, dropout_rate, attention_dropout_rate, **kwargs\n",
677
+ " ):\n",
678
+ " super(Encoder, self).__init__(**kwargs)\n",
679
+ "\n",
680
+ " # Multi-head Attention\n",
681
+ " self.mha = MultiHeadAttention(\n",
682
+ " num_heads=num_heads,\n",
683
+ " key_dim=embed_dim,\n",
684
+ " dropout=attention_dropout_rate,\n",
685
+ " kernel_initializer=TruncatedNormal(stddev=0.02),\n",
686
+ " attention_axes=(1, 2), # 2D attention (timestep, patch)\n",
687
+ " )\n",
688
+ "\n",
689
+ " # Point wise feed forward network\n",
690
+ " self.dense_0 = Dense(\n",
691
+ " units=mlp_dim,\n",
692
+ " activation=\"gelu\",\n",
693
+ " kernel_initializer=TruncatedNormal(stddev=0.02),\n",
694
+ " )\n",
695
+ " self.dense_1 = Dense(\n",
696
+ " units=embed_dim, kernel_initializer=TruncatedNormal(stddev=0.02)\n",
697
+ " )\n",
698
+ "\n",
699
+ " self.dropout_0 = Dropout(rate=dropout_rate)\n",
700
+ " self.dropout_1 = Dropout(rate=dropout_rate)\n",
701
+ "\n",
702
+ " self.norm_0 = LayerNormalization(epsilon=1e-12)\n",
703
+ " self.norm_1 = LayerNormalization(epsilon=1e-12)\n",
704
+ "\n",
705
+ " self.add_0 = Add()\n",
706
+ " self.add_1 = Add()\n",
707
+ "\n",
708
+ " def call(self, inputs, training):\n",
709
+ " # Attention block\n",
710
+ " x = self.norm_0(inputs)\n",
711
+ " x = self.mha(\n",
712
+ " query=x,\n",
713
+ " key=x,\n",
714
+ " value=x,\n",
715
+ " training=training,\n",
716
+ " )\n",
717
+ " x = self.dropout_0(x, training=training)\n",
718
+ " x = self.add_0([x, inputs])\n",
719
+ "\n",
720
+ " # MLP block\n",
721
+ " y = self.norm_1(x)\n",
722
+ " y = self.dense_0(y)\n",
723
+ " y = self.dense_1(y)\n",
724
+ " y = self.dropout_1(y, training=training)\n",
725
+ "\n",
726
+ " return self.add_1([x, y])"
727
+ ]
728
+ },
729
+ {
730
+ "cell_type": "code",
731
+ "execution_count": null,
732
+ "metadata": {
733
+ "id": "V3n2jEdBRzEo"
734
+ },
735
+ "outputs": [],
736
+ "source": [
737
+ "class Decoder(Layer):\n",
738
+ " def __init__(\n",
739
+ " self, embed_dim, mlp_dim, num_heads, dropout_rate, attention_dropout_rate, **kwargs\n",
740
+ " ):\n",
741
+ " super(Decoder, self).__init__(**kwargs)\n",
742
+ "\n",
743
+ " # MultiHeadAttention\n",
744
+ " self.mha_0 = MultiHeadAttention(\n",
745
+ " num_heads=num_heads,\n",
746
+ " key_dim=embed_dim,\n",
747
+ " dropout=attention_dropout_rate,\n",
748
+ " kernel_initializer=TruncatedNormal(stddev=0.02),\n",
749
+ " attention_axes=(1, 2), # 2D attention (timestep, patch)\n",
750
+ " )\n",
751
+ " self.mha_1 = MultiHeadAttention(\n",
752
+ " num_heads=num_heads,\n",
753
+ " key_dim=embed_dim,\n",
754
+ " dropout=attention_dropout_rate,\n",
755
+ " kernel_initializer=TruncatedNormal(stddev=0.02),\n",
756
+ " attention_axes=(1, 2), # 2D attention (timestep, patch)\n",
757
+ " )\n",
758
+ "\n",
759
+ " # Point wise feed forward network\n",
760
+ " self.dense_0 = Dense(\n",
761
+ " units=mlp_dim,\n",
762
+ " activation=\"gelu\",\n",
763
+ " kernel_initializer=TruncatedNormal(stddev=0.02),\n",
764
+ " )\n",
765
+ " self.dense_1 = Dense(\n",
766
+ " units=embed_dim, kernel_initializer=TruncatedNormal(stddev=0.02)\n",
767
+ " )\n",
768
+ "\n",
769
+ " self.dropout_0 = Dropout(rate=dropout_rate)\n",
770
+ " self.dropout_1 = Dropout(rate=dropout_rate)\n",
771
+ " self.dropout_2 = Dropout(rate=dropout_rate)\n",
772
+ "\n",
773
+ " self.norm_0 = LayerNormalization(epsilon=1e-12)\n",
774
+ " self.norm_1 = LayerNormalization(epsilon=1e-12)\n",
775
+ " self.norm_2 = LayerNormalization(epsilon=1e-12)\n",
776
+ "\n",
777
+ " self.add_0 = Add()\n",
778
+ " self.add_1 = Add()\n",
779
+ " self.add_2 = Add()\n",
780
+ "\n",
781
+ " def call(self, inputs, enc_output, training):\n",
782
+ " # Attention block\n",
783
+ " x = self.norm_0(inputs)\n",
784
+ " x = self.mha_0(\n",
785
+ " query=x,\n",
786
+ " key=x,\n",
787
+ " value=x,\n",
788
+ " training=training,\n",
789
+ " )\n",
790
+ " x = self.dropout_0(x, training=training)\n",
791
+ " x = self.add_0([x, inputs])\n",
792
+ "\n",
793
+ " # Attention block\n",
794
+ " y = self.norm_1(x)\n",
795
+ " y = self.mha_1(\n",
796
+ " query=y,\n",
797
+ " key=enc_output,\n",
798
+ " value=enc_output,\n",
799
+ " training=training,\n",
800
+ " )\n",
801
+ " y = self.dropout_1(y, training=training)\n",
802
+ " y = self.add_1([x, y])\n",
803
+ "\n",
804
+ " # MLP block\n",
805
+ " z = self.norm_2(y)\n",
806
+ " z = self.dense_0(z)\n",
807
+ " z = self.dense_1(z)\n",
808
+ " z = self.dropout_2(z, training=training)\n",
809
+ "\n",
810
+ " return self.add_2([y, z])"
811
+ ]
812
+ },
813
+ {
814
+ "cell_type": "markdown",
815
+ "metadata": {
816
+ "id": "7O_O6FKlRzE1"
817
+ },
818
+ "source": [
819
+ "## Model"
820
+ ]
821
+ },
822
+ {
823
+ "cell_type": "code",
824
+ "execution_count": null,
825
+ "metadata": {
826
+ "id": "jiO-dcXURzE8"
827
+ },
828
+ "outputs": [],
829
+ "source": [
830
+ "class DailyTransformer(Model):\n",
831
+ " def __init__(\n",
832
+ " self,\n",
833
+ " num_encoder_layers,\n",
834
+ " num_decoder_layers,\n",
835
+ " embed_dim,\n",
836
+ " mlp_dim,\n",
837
+ " num_heads,\n",
838
+ " num_outputs,\n",
839
+ " dropout_rate,\n",
840
+ " attention_dropout_rate,\n",
841
+ " **kwargs\n",
842
+ " ):\n",
843
+ " super(DailyTransformer, self).__init__(**kwargs)\n",
844
+ "\n",
845
+ " # Input (normalization of RAW measurements)\n",
846
+ " self.input_norm_enc = Normalization(invert=False)\n",
847
+ " self.input_norm_dec1 = Normalization(invert=False)\n",
848
+ " self.input_norm_dec2 = Normalization(invert=True)\n",
849
+ "\n",
850
+ " # Input\n",
851
+ " self.pos_embs_0 = PositionalEmbedding(embed_dim, dropout_rate)\n",
852
+ " self.pos_embs_1 = PositionalEmbedding(embed_dim, dropout_rate)\n",
853
+ "\n",
854
+ " # Encoder\n",
855
+ " self.enc_layers = [\n",
856
+ " Encoder(embed_dim, mlp_dim, num_heads, dropout_rate, attention_dropout_rate)\n",
857
+ " for _ in range(num_encoder_layers)\n",
858
+ " ]\n",
859
+ " self.norm_0 = LayerNormalization(epsilon=1e-12)\n",
860
+ "\n",
861
+ " # Decoder\n",
862
+ " self.dec_layers = [\n",
863
+ " Decoder(embed_dim, mlp_dim, num_heads, dropout_rate, attention_dropout_rate)\n",
864
+ " for _ in range(num_decoder_layers)\n",
865
+ " ]\n",
866
+ " self.norm_1 = LayerNormalization(epsilon=1e-12)\n",
867
+ "\n",
868
+ " # Output\n",
869
+ " self.final_layer = Dense(\n",
870
+ " units=num_outputs,\n",
871
+ " kernel_initializer=TruncatedNormal(stddev=0.02),\n",
872
+ " )\n",
873
+ "\n",
874
+ " def call(self, inputs, training):\n",
875
+ " inputs, targets = inputs\n",
876
+ "\n",
877
+ " # Encoder input\n",
878
+ " x_e = self.input_norm_enc(inputs)\n",
879
+ " x_e = self.pos_embs_0(x_e, training=training)\n",
880
+ "\n",
881
+ " # Encoder\n",
882
+ " for layer in self.enc_layers:\n",
883
+ " x_e = layer(x_e, training=training)\n",
884
+ " x_e = self.norm_0(x_e)\n",
885
+ "\n",
886
+ " # Decoder input\n",
887
+ " x_d = self.input_norm_dec1(targets)\n",
888
+ " x_d = self.pos_embs_1(x_d, training=training)\n",
889
+ "\n",
890
+ " # Decoder\n",
891
+ " for layer in self.dec_layers:\n",
892
+ " x_d = layer(x_d, x_e, training=training)\n",
893
+ " x_d = self.norm_1(x_d)\n",
894
+ "\n",
895
+ " # Output\n",
896
+ " final_output = self.final_layer(x_d)\n",
897
+ " final_output = self.input_norm_dec2(final_output)\n",
898
+ "\n",
899
+ " return final_output\n",
900
+ "\n",
901
+ " def train_step(self, inputs):\n",
902
+ " inputs, targets = inputs\n",
903
+ " inputs = inputs[:, :-1]\n",
904
+ " targets_inputs = targets[:, :-1]\n",
905
+ " targets_real = targets[:, 1:, :, -1:]\n",
906
+ "\n",
907
+ " with tf.GradientTape() as tape:\n",
908
+ " y_pred = self([inputs, targets_inputs], training=True)\n",
909
+ " loss = self.compiled_loss(targets_real, y_pred, regularization_losses=self.losses)\n",
910
+ "\n",
911
+ " print(y_pred)\n",
912
+ " print(targets_real)\n",
913
+ "\n",
914
+ " # Compute gradients\n",
915
+ " trainable_vars = self.trainable_variables\n",
916
+ " gradients = tape.gradient(loss, trainable_vars)\n",
917
+ "\n",
918
+ " # Update weights\n",
919
+ " self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n",
920
+ "\n",
921
+ " # Update metrics (includes the metric that tracks the loss)\n",
922
+ " self.compiled_metrics.update_state(targets_real[:, -1], y_pred[:, -1])\n",
923
+ "\n",
924
+ " # Return a dict mapping metric names to current value\n",
925
+ " return {m.name: m.result() for m in self.metrics}\n",
926
+ " \n",
927
+ " def test_step(self, inputs):\n",
928
+ " inputs, targets = inputs\n",
929
+ " inputs = inputs[:, :-1]\n",
930
+ " targets_inputs = targets[:, :-1]\n",
931
+ " targets_real = targets[:, 1:, :, -1:]\n",
932
+ "\n",
933
+ " # Compute predictions\n",
934
+ " y_pred = self([inputs, targets_inputs], training=False)\n",
935
+ "\n",
936
+ " # Updates the metrics tracking the loss\n",
937
+ " self.compiled_loss(targets_real, y_pred, regularization_losses=self.losses)\n",
938
+ "\n",
939
+ " # Update the metrics\n",
940
+ " self.compiled_metrics.update_state(targets_real[:, -1], y_pred[:, -1])\n",
941
+ "\n",
942
+ " # Return a dict mapping metric names to current value\n",
943
+ " # Note that it will include the loss (tracked in self.metrics)\n",
944
+ " return {m.name: m.result() for m in self.metrics}"
945
+ ]
946
+ },
947
+ {
948
+ "cell_type": "markdown",
949
+ "metadata": {
950
+ "id": "LwEwVCXTRzFx"
951
+ },
952
+ "source": [
953
+ "## LR scheduler"
954
+ ]
955
+ },
956
+ {
957
+ "cell_type": "code",
958
+ "execution_count": null,
959
+ "metadata": {
960
+ "id": "6U3PZiLzRzF1"
961
+ },
962
+ "outputs": [],
963
+ "source": [
964
+ "def cosine_schedule(base_lr, total_steps, warmup_steps):\n",
965
+ " def step_fn(epoch):\n",
966
+ " lr = base_lr\n",
967
+ " epoch += 1\n",
968
+ "\n",
969
+ " progress = (epoch - warmup_steps) / float(total_steps - warmup_steps)\n",
970
+ " progress = tf.clip_by_value(progress, 0.0, 1.0)\n",
971
+ " \n",
972
+ " lr = lr * 0.5 * (1.0 + tf.cos(math.pi * progress))\n",
973
+ "\n",
974
+ " if warmup_steps:\n",
975
+ " lr = lr * tf.minimum(1.0, epoch / warmup_steps)\n",
976
+ "\n",
977
+ " return lr\n",
978
+ "\n",
979
+ " return step_fn"
980
+ ]
981
+ },
982
+ {
983
+ "cell_type": "code",
984
+ "execution_count": null,
985
+ "metadata": {
986
+ "id": "RkGDtc9sRzF6"
987
+ },
988
+ "outputs": [],
989
+ "source": [
990
+ "class PrintLR(Callback):\n",
991
+ " def on_epoch_end(self, epoch, logs=None):\n",
992
+ " wandb.log({\"lr\": self.model.optimizer.lr.numpy()}, commit=False)"
993
+ ]
994
+ },
995
+ {
996
+ "cell_type": "markdown",
997
+ "metadata": {
998
+ "id": "GFi9Y2pORzGA"
999
+ },
1000
+ "source": [
1001
+ "## Daily Dataset"
1002
+ ]
1003
+ },
1004
+ {
1005
+ "cell_type": "code",
1006
+ "execution_count": null,
1007
+ "metadata": {
1008
+ "id": "Dn615CW3I1nL"
1009
+ },
1010
+ "outputs": [],
1011
+ "source": [
1012
+ "df_X = pd.read_csv(\"/content/drive/MyDrive/Solar-Transformer/1984_2022/X_all_daily.csv\")\n",
1013
+ "df_y_daily = pd.read_csv(\"/content/drive/MyDrive/Solar-Transformer/1984_2022/y_all_daily.csv\")"
1014
+ ]
1015
+ },
1016
+ {
1017
+ "cell_type": "code",
1018
+ "execution_count": null,
1019
+ "metadata": {
1020
+ "colab": {
1021
+ "base_uri": "https://localhost:8080/",
1022
+ "height": 89
1023
+ },
1024
+ "id": "1gCRYg_Gzo9U",
1025
+ "outputId": "c73a6bf2-40d0-48de-8356-5be6207c16c3"
1026
+ },
1027
+ "outputs": [],
1028
+ "source": [
1029
+ "plt.hist2d(df_X['WindDirection1'], df_X['WindSpeed1'], bins=(50, 50))\n",
1030
+ "plt.colorbar()\n",
1031
+ "plt.xlabel('Wind Direction [deg]')\n",
1032
+ "plt.ylabel('Wind Velocity [m/s]')\n",
1033
+ "plt.title(\"Wind\")\n",
1034
+ "plt.show()\n",
1035
+ "\n",
1036
+ "plt.hist2d(df_X['WindDirection1'], df_X['WindSpeedMin1'], bins=(50, 50))\n",
1037
+ "plt.colorbar()\n",
1038
+ "plt.xlabel('Wind Direction [deg]')\n",
1039
+ "plt.ylabel('Min Wind Velocity [m/s]')\n",
1040
+ "plt.title(\"Min Wind\")\n",
1041
+ "plt.show()\n",
1042
+ "\n",
1043
+ "plt.hist2d(df_X['WindDirection1'], df_X['WindSpeedMax1'], bins=(50, 50))\n",
1044
+ "plt.colorbar()\n",
1045
+ "plt.xlabel('Wind Direction [deg]')\n",
1046
+ "plt.ylabel('Max Wind Velocity [m/s]')\n",
1047
+ "plt.title(\"Max Wind\")\n",
1048
+ "plt.show()"
1049
+ ]
1050
+ },
1051
+ {
1052
+ "cell_type": "code",
1053
+ "execution_count": null,
1054
+ "metadata": {
1055
+ "id": "TfiNz_9LkNF7"
1056
+ },
1057
+ "outputs": [],
1058
+ "source": [
1059
+ "date_time = pd.to_datetime(df_X.pop('DateTime'), format='%Y-%m-%d')\n",
1060
+ "num_of_patches = df_X['Name'].nunique()\n",
1061
+ "\n",
1062
+ "df_X = df_X.drop(\n",
1063
+ " columns=['Name', 'Latitude', 'Longitude'] +\n",
1064
+ " [c for c in df_X.columns if c[:9] == 'WindSpeed'] +\n",
1065
+ " [c for c in df_X.columns if c[:12] == 'WindSpeedMin'] +\n",
1066
+ " [c for c in df_X.columns if c[:12] == 'WindSpeedMax'] +\n",
1067
+ " [c for c in df_X.columns if c[:13] == 'WindDirection']\n",
1068
+ ")\n",
1069
+ "df_y_daily = df_y_daily.drop(\n",
1070
+ " columns=['DateTime', 'Name', 'Latitude', 'Longitude'] +\n",
1071
+ " [c for c in df_y_daily.columns if c[:9] == 'WindSpeed'] +\n",
1072
+ " [c for c in df_y_daily.columns if c[:12] == 'WindSpeedMin'] +\n",
1073
+ " [c for c in df_y_daily.columns if c[:12] == 'WindSpeedMax'] +\n",
1074
+ " [c for c in df_y_daily.columns if c[:13] == 'WindDirection']\n",
1075
+ ")"
1076
+ ]
1077
+ },
1078
+ {
1079
+ "cell_type": "code",
1080
+ "execution_count": null,
1081
+ "metadata": {
1082
+ "colab": {
1083
+ "base_uri": "https://localhost:8080/"
1084
+ },
1085
+ "id": "SYWfWTrx-WQy",
1086
+ "outputId": "89eb1644-7546-4a57-829a-304db5954445"
1087
+ },
1088
+ "outputs": [],
1089
+ "source": [
1090
+ "print(df_X.head())\n",
1091
+ "print(df_y_daily.head())"
1092
+ ]
1093
+ },
1094
+ {
1095
+ "cell_type": "code",
1096
+ "execution_count": null,
1097
+ "metadata": {
1098
+ "colab": {
1099
+ "base_uri": "https://localhost:8080/",
1100
+ "height": 89
1101
+ },
1102
+ "id": "OXZLIkjIW1Nn",
1103
+ "outputId": "e8a07e48-ad19-4c85-cfd6-4040a00f95fe"
1104
+ },
1105
+ "outputs": [],
1106
+ "source": [
1107
+ "plt.hist2d(df_X['WindX1'], df_X['WindY1'], bins=(50, 50))\n",
1108
+ "plt.colorbar()\n",
1109
+ "plt.xlabel('Wind X [m/s]')\n",
1110
+ "plt.ylabel('Wind Y [m/s]')\n",
1111
+ "plt.title(\"Wind vector\")\n",
1112
+ "ax = plt.gca()\n",
1113
+ "ax.axis('tight')\n",
1114
+ "plt.show()\n",
1115
+ "\n",
1116
+ "plt.hist2d(df_X['WindXMin1'], df_X['WindYMin1'], bins=(50, 50))\n",
1117
+ "plt.colorbar()\n",
1118
+ "plt.xlabel('Min Wind X [m/s]')\n",
1119
+ "plt.ylabel('Min Wind Y [m/s]')\n",
1120
+ "plt.title(\"Min Wind vector\")\n",
1121
+ "ax = plt.gca()\n",
1122
+ "ax.axis('tight')\n",
1123
+ "plt.show()\n",
1124
+ "\n",
1125
+ "plt.hist2d(df_X['WindXMax1'], df_X['WindYMax1'], bins=(50, 50))\n",
1126
+ "plt.colorbar()\n",
1127
+ "plt.xlabel('Max Wind X [m/s]')\n",
1128
+ "plt.ylabel('Max Wind Y [m/s]')\n",
1129
+ "plt.title(\"Max Wind vector\")\n",
1130
+ "ax = plt.gca()\n",
1131
+ "ax.axis('tight')\n",
1132
+ "plt.show()"
1133
+ ]
1134
+ },
1135
+ {
1136
+ "cell_type": "code",
1137
+ "execution_count": null,
1138
+ "metadata": {
1139
+ "colab": {
1140
+ "base_uri": "https://localhost:8080/",
1141
+ "height": 124
1142
+ },
1143
+ "id": "oXGHnjU62ooH",
1144
+ "outputId": "d0b13b5f-24a6-432e-847d-34edfcb7644f"
1145
+ },
1146
+ "outputs": [],
1147
+ "source": [
1148
+ "x_data = date_time[:(5856 + 5840 + 5840 + 5840):num_of_patches]\n",
1149
+ "\n",
1150
+ "plt.figure(figsize=(16, 4))\n",
1151
+ "plt.plot(x_data, df_X[\"Irradiance1\"][:(5856 + 5840 + 5840 + 5840):num_of_patches])\n",
1152
+ "plt.ylabel('kW-hr/m^2/day')\n",
1153
+ "plt.xlabel(\"Date\")\n",
1154
+ "plt.title(\"Solar irradiance\")\n",
1155
+ "plt.show()\n",
1156
+ "\n",
1157
+ "plt.figure(figsize=(16, 4))\n",
1158
+ "plt.plot(x_data, df_X[\"Temp1\"][:(5856 + 5840 + 5840 + 5840):num_of_patches])\n",
1159
+ "plt.plot(x_data, df_X[\"TempMin1\"][:(5856 + 5840 + 5840 + 5840):num_of_patches])\n",
1160
+ "plt.plot(x_data, df_X[\"TempMax1\"][:(5856 + 5840 + 5840 + 5840):num_of_patches])\n",
1161
+ "plt.ylabel('°C')\n",
1162
+ "plt.xlabel(\"Date\")\n",
1163
+ "plt.title(\"Temperature\")\n",
1164
+ "plt.legend([\"Mean\", \"Min\", \"Max\"])\n",
1165
+ "plt.show()\n",
1166
+ "\n",
1167
+ "plt.figure(figsize=(16, 4))\n",
1168
+ "plt.plot(x_data, df_X[\"Humidity1\"][:(5856 + 5840 + 5840 + 5840):num_of_patches])\n",
1169
+ "plt.ylabel('%')\n",
1170
+ "plt.xlabel(\"Date\")\n",
1171
+ "plt.title(\"Humidity\")\n",
1172
+ "plt.show()\n",
1173
+ "\n",
1174
+ "plt.figure(figsize=(16, 4))\n",
1175
+ "plt.plot(x_data, df_X[\"Pressure1\"][:(5856 + 5840 + 5840 + 5840):num_of_patches])\n",
1176
+ "plt.ylabel('kPa')\n",
1177
+ "plt.xlabel(\"Date\")\n",
1178
+ "plt.title(\"Pressure\")\n",
1179
+ "plt.show()\n",
1180
+ "\n",
1181
+ "plt.figure(figsize=(16, 4))\n",
1182
+ "plt.plot(x_data, df_X[\"DaySin\"][:(5856 + 5840 + 5840 + 5840):num_of_patches])\n",
1183
+ "plt.plot(x_data, df_X[\"DayCos\"][:(5856 + 5840 + 5840 + 5840):num_of_patches])\n",
1184
+ "plt.xlabel(\"Date\")\n",
1185
+ "plt.title(\"Time of year signal\")\n",
1186
+ "plt.show()"
1187
+ ]
1188
+ },
1189
+ {
1190
+ "cell_type": "markdown",
1191
+ "metadata": {
1192
+ "id": "Zx4fxfnTSdDE"
1193
+ },
1194
+ "source": [
1195
+ "## Dataset"
1196
+ ]
1197
+ },
1198
+ {
1199
+ "cell_type": "code",
1200
+ "execution_count": null,
1201
+ "metadata": {
1202
+ "id": "pg3PA-58zvaW"
1203
+ },
1204
+ "outputs": [],
1205
+ "source": [
1206
+ "def make_dataset(data, sequence_length, sequence_stride, sampling_rate):\n",
1207
+ " def make_window(data):\n",
1208
+ " dataset = tf.data.Dataset.from_tensor_slices(data)\n",
1209
+ " dataset = dataset.window(sequence_length, shift=sequence_stride, stride=sampling_rate, drop_remainder=True)\n",
1210
+ " dataset = dataset.flat_map(lambda x: x.batch(sequence_length, drop_remainder=True)) \n",
1211
+ " return dataset\n",
1212
+ "\n",
1213
+ " data = np.array(data, dtype=np.float32)\n",
1214
+ " data = np.reshape(data, (-1, num_of_patches, data.shape[-1]))\n",
1215
+ "\n",
1216
+ " # Split the data\n",
1217
+ " # (80%, 10%, 10%)\n",
1218
+ " n = data.shape[0]\n",
1219
+ " n_train = int(n*0.8)\n",
1220
+ " n_val = int(n*0.9)\n",
1221
+ " train_data = data[0:n_train]\n",
1222
+ " val_data = data[n_train:n_val]\n",
1223
+ " test_data = data[n_val:]\n",
1224
+ "\n",
1225
+ " return (\n",
1226
+ " (n_train, make_window(train_data)),\n",
1227
+ " (n_val - n_train, make_window(val_data)),\n",
1228
+ " make_window(test_data)\n",
1229
+ " )\n",
1230
+ "\n",
1231
+ "def merge_dataset(datasets, batch_size, shuffle):\n",
1232
+ " dataset = tf.data.Dataset.zip(datasets)\n",
1233
+ " dataset = dataset.prefetch(tf.data.AUTOTUNE)\n",
1234
+ "\n",
1235
+ " if shuffle:\n",
1236
+ " # Shuffle locally at each iteration\n",
1237
+ " dataset = dataset.shuffle(buffer_size=1000)\n",
1238
+ " dataset = dataset.batch(batch_size)\n",
1239
+ " \n",
1240
+ " return dataset"
1241
+ ]
1242
+ },
1243
+ {
1244
+ "cell_type": "markdown",
1245
+ "metadata": {
1246
+ "id": "s0-fDcMDRzGS"
1247
+ },
1248
+ "source": [
1249
+ "## Training loop"
1250
+ ]
1251
+ },
1252
+ {
1253
+ "cell_type": "code",
1254
+ "execution_count": null,
1255
+ "metadata": {
1256
+ "id": "Ogf0A_urhRCh"
1257
+ },
1258
+ "outputs": [],
1259
+ "source": [
1260
+ "def training_loop(cfg):\n",
1261
+ " # load dataset\n",
1262
+ " (n_train_X, train_X_ds), (n_val_X, val_X_ds), _ = make_dataset(df_X, (cfg.window_size + 1), 1, 1)\n",
1263
+ " (n_train_y, train_y_daily_ds), (n_val_y, val_y_daily_ds), _ = make_dataset(df_y_daily, (cfg.window_size + 1), 1, 1)\n",
1264
+ " assert n_train_X == n_train_y\n",
1265
+ " assert n_val_X == n_val_y\n",
1266
+ "\n",
1267
+ " train_ds = merge_dataset(\n",
1268
+ " (train_X_ds, train_y_daily_ds),\n",
1269
+ " cfg.batch_size,\n",
1270
+ " shuffle=True,\n",
1271
+ " )\n",
1272
+ " val_ds = merge_dataset(\n",
1273
+ " (val_X_ds, val_y_daily_ds),\n",
1274
+ " cfg.batch_size,\n",
1275
+ " shuffle=False,\n",
1276
+ " )\n",
1277
+ "\n",
1278
+ " # Generate new model\n",
1279
+ " daily_model = DailyTransformer(\n",
1280
+ " num_encoder_layers=cfg.num_encoder_layers,\n",
1281
+ " num_decoder_layers=cfg.num_decoder_layers,\n",
1282
+ " embed_dim=cfg.embed_layer_size,\n",
1283
+ " mlp_dim=cfg.fc_layer_size,\n",
1284
+ " num_heads=cfg.num_heads,\n",
1285
+ " num_outputs=1,\n",
1286
+ " dropout_rate=cfg.dropout,\n",
1287
+ " attention_dropout_rate=cfg.attention_dropout,\n",
1288
+ " )\n",
1289
+ "\n",
1290
+ " # adapt on inputs of training dataset - must be before model.compile !!!\n",
1291
+ " daily_model.input_norm_enc.adapt(train_X_ds)\n",
1292
+ " print(daily_model.input_norm_enc.variables)\n",
1293
+ "\n",
1294
+ " # adapt on targets of training dataset - must be before model.compile !!!\n",
1295
+ " daily_model.input_norm_dec1.adapt(train_y_daily_ds)\n",
1296
+ " print(daily_model.input_norm_dec1.variables)\n",
1297
+ " daily_model.input_norm_dec2.adapt(train_y_daily_ds.map(lambda x: x[:, :, -1:]))\n",
1298
+ " print(daily_model.input_norm_dec2.variables)\n",
1299
+ "\n",
1300
+ " # Select optimizer\n",
1301
+ " if cfg.optimizer == \"adam\":\n",
1302
+ " optim = Adam(\n",
1303
+ " beta_1=0.9,\n",
1304
+ " beta_2=0.999,\n",
1305
+ " epsilon=1e-08,\n",
1306
+ " global_clipnorm=cfg.global_clipnorm,\n",
1307
+ " )\n",
1308
+ " elif cfg.optimizer == \"adamw\":\n",
1309
+ " optim = tfa.optimizers.AdamW(\n",
1310
+ " weight_decay=cfg.weight_decay,\n",
1311
+ " beta_1=0.9,\n",
1312
+ " beta_2=0.999,\n",
1313
+ " epsilon=1e-08,\n",
1314
+ " global_clipnorm=cfg.global_clipnorm,\n",
1315
+ " exclude_from_weight_decay=[\"layer_normalization\", \"bias\", \"temporal_position\", \"spatial_position\"],\n",
1316
+ " )\n",
1317
+ " else:\n",
1318
+ " raise ValueError(\"The used optimizer is not in list of available\")\n",
1319
+ "\n",
1320
+ " daily_model.compile(\n",
1321
+ " optimizer=optim,\n",
1322
+ " loss=\"log_cosh\",\n",
1323
+ " metrics=[MeanSquaredError(), RootMeanSquaredError(), MeanAbsoluteError(), RSquare()] \n",
1324
+ " )\n",
1325
+ "\n",
1326
+ " # Train model\n",
1327
+ " daily_model.fit(\n",
1328
+ " train_ds,\n",
1329
+ " epochs=cfg.epochs,\n",
1330
+ " validation_data=val_ds,\n",
1331
+ " callbacks=[\n",
1332
+ " LearningRateScheduler(cosine_schedule(base_lr=cfg.learning_rate, total_steps=cfg.epochs, warmup_steps=cfg.warmup_steps)),\n",
1333
+ " PrintLR(),\n",
1334
+ " WandbCallback(monitor=\"val_mean_squared_error\", mode='min', save_weights_only=True),\n",
1335
+ " EarlyStopping(monitor=\"val_mean_squared_error\", mode='min', min_delta=1e-4, patience=10, restore_best_weights=True, verbose=1),\n",
1336
+ " ],\n",
1337
+ " verbose=1\n",
1338
+ " )\n",
1339
+ "\n",
1340
+ " daily_model.summary()\n",
1341
+ "\n",
1342
+ " patch_similarity_plot(daily_model.pos_embs_0.spatial_position[0, 0])\n",
1343
+ " patch_similarity_plot(daily_model.pos_embs_1.spatial_position[0, 0])\n",
1344
+ " \n",
1345
+ " timestep_similarity_plot(daily_model.pos_embs_0.temporal_position[0, :, 0])\n",
1346
+ " timestep_similarity_plot(daily_model.pos_embs_1.temporal_position[0, :, 0])\n",
1347
+ "\n",
1348
+ " for inputs in val_ds.take(1):\n",
1349
+ " plot_prediction(inputs, daily_model)\n",
1350
+ "\n",
1351
+ " # Resets all state generated by Keras\n",
1352
+ " tf.keras.backend.clear_session()"
1353
+ ]
1354
+ },
1355
+ {
1356
+ "cell_type": "code",
1357
+ "execution_count": null,
1358
+ "metadata": {
1359
+ "id": "RZLpSuL-RzGb"
1360
+ },
1361
+ "outputs": [],
1362
+ "source": [
1363
+ "def run(config=None):\n",
1364
+ " with wandb.init(config=config):\n",
1365
+ " config = wandb.config\n",
1366
+ "\n",
1367
+ " # check rules\n",
1368
+ " if (config.fc_layer_size < config.embed_layer_size):\n",
1369
+ " return\n",
1370
+ " elif (config.warmup_steps >= config.epochs):\n",
1371
+ " return\n",
1372
+ "\n",
1373
+ " training_loop(config)"
1374
+ ]
1375
+ },
1376
+ {
1377
+ "cell_type": "code",
1378
+ "execution_count": null,
1379
+ "metadata": {
1380
+ "colab": {
1381
+ "base_uri": "https://localhost:8080/"
1382
+ },
1383
+ "id": "AXDG0ODuRzGj",
1384
+ "outputId": "aaf73d8c-9c01-4dc7-b3da-df7f9d07a78a"
1385
+ },
1386
+ "outputs": [],
1387
+ "source": [
1388
+ "wandb.agent(sweep_id, run, count=1024)"
1389
+ ]
1390
+ }
1391
+ ],
1392
+ "metadata": {
1393
+ "accelerator": "GPU",
1394
+ "colab": {
1395
+ "collapsed_sections": [
1396
+ "XgDBs9_3l4uD",
1397
+ "oQyRcTjTRzEE",
1398
+ "7O_O6FKlRzE1",
1399
+ "LwEwVCXTRzFx",
1400
+ "GFi9Y2pORzGA",
1401
+ "Zx4fxfnTSdDE"
1402
+ ],
1403
+ "machine_shape": "hm",
1404
+ "name": "Solar_Transformer_2.ipynb",
1405
+ "provenance": []
1406
+ },
1407
+ "gpuClass": "standard",
1408
+ "kernelspec": {
1409
+ "display_name": "Python 3.9.10 ('base')",
1410
+ "language": "python",
1411
+ "name": "python3"
1412
+ },
1413
+ "language_info": {
1414
+ "codemirror_mode": {
1415
+ "name": "ipython",
1416
+ "version": 3
1417
+ },
1418
+ "file_extension": ".py",
1419
+ "mimetype": "text/x-python",
1420
+ "name": "python",
1421
+ "nbconvert_exporter": "python",
1422
+ "pygments_lexer": "ipython3",
1423
+ "version": "3.9.10"
1424
+ },
1425
+ "vscode": {
1426
+ "interpreter": {
1427
+ "hash": "9185113d2128201d66faecd4f34fb34e89a635073a034991399523e584519355"
1428
+ }
1429
+ }
1430
+ },
1431
+ "nbformat": 4,
1432
+ "nbformat_minor": 0
1433
+ }
img/Solar_Transformer.png ADDED
img/output.png ADDED
models/model-best.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d6846b5ec551c96968b7154d6dba026320eaabcbd36457e9bc555896ff22b21
3
+ size 7534528