Commit ae037e43 authored by Julia Wagemann's avatar Julia Wagemann
Browse files

LTPy v0.9 release

parent 1cd1e0bd
This diff is collapsed.
%% Cell type:code id:61bc4bcb tags:
``` python
from matplotlib import pyplot as plt
import matplotlib.colors
from matplotlib.colors import LogNorm
import cartopy.crs as ccrs
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import cartopy.feature as cfeature
import matplotlib.cm as cm
import warnings
warnings.simplefilter(action = "ignore", category = RuntimeWarning)
warnings.simplefilter(action = "ignore", category = FutureWarning)
```
%% Cell type:markdown id:64f71561 tags:
### <a id='visualize_pcolormesh'></a>`visualize_pcolormesh`
%% Cell type:code id:0f5fd817 tags:
``` python
def visualize_pcolormesh(data_array, longitude, latitude, projection, color_scale, unit, long_name, vmin, vmax,
set_global=True, lonmin=-180, lonmax=180, latmin=-90, latmax=90):
"""
Visualizes a xarray.DataArray with matplotlib's pcolormesh function.
Parameters:
data_array(xarray.DataArray): xarray.DataArray holding the data values
longitude(xarray.DataArray): xarray.DataArray holding the longitude values
latitude(xarray.DataArray): xarray.DataArray holding the latitude values
projection(str): a projection provided by the cartopy library, e.g. ccrs.PlateCarree()
color_scale(str): string taken from matplotlib's color ramp reference
unit(str): the unit of the parameter, taken from the NetCDF file if possible
long_name(str): long name of the parameter, taken from the NetCDF file if possible
vmin(int): minimum number on visualisation legend
vmax(int): maximum number on visualisation legend
set_global(boolean): optional kwarg, default is True
lonmin,lonmax,latmin,latmax(float): optional kwarg, set geographic extent is set_global kwarg is set to
False
"""
fig=plt.figure(figsize=(20, 10))
ax = plt.axes(projection=projection)
img = plt.pcolormesh(longitude, latitude, data_array,
cmap=plt.get_cmap(color_scale), transform=ccrs.PlateCarree(),
vmin=vmin,
vmax=vmax,
shading='auto')
ax.add_feature(cfeature.BORDERS, edgecolor='black', linewidth=1)
ax.add_feature(cfeature.COASTLINE, edgecolor='black', linewidth=1)
if (projection==ccrs.PlateCarree()):
ax.set_extent([lonmin, lonmax, latmin, latmax], projection)
gl = ax.gridlines(draw_labels=True, linestyle='--')
gl.top_labels=False
gl.right_labels=False
gl.xformatter=LONGITUDE_FORMATTER
gl.yformatter=LATITUDE_FORMATTER
gl.xlabel_style={'size':14}
gl.ylabel_style={'size':14}
if(set_global):
ax.set_global()
ax.gridlines()
cbar = fig.colorbar(img, ax=ax, orientation='horizontal', fraction=0.04, pad=0.1)
cbar.set_label(unit, fontsize=16)
cbar.ax.tick_params(labelsize=14)
ax.set_title(long_name, fontsize=20, pad=20.0)
# plt.show()
return fig, ax
```
%% Cell type:code id:b52ed025 tags:
``` python
```
%% Cell type:code id:bfd68233 tags:
``` python
```
%% Cell type:code id:ec4a379c tags:
``` python
from fastai.data.all import *
from fastai.vision.all import *
from fastai.vision.core import *
from fastai.vision.data import *
import pdb
import utils
from utils.UGeop import GRaster # library made by me (Maximilien Houël) for geoprocessings
R = GRaster.Raster()
```
%% Cell type:code id:f1d3c9d0 tags:
``` python
import torch
print(torch.cuda.get_device_name(0))
```
%% Cell type:markdown id:97d57b75 tags:
# -Set up the environment-
%% Cell type:markdown id:4c5ba873 tags:
For the inference the model saved needs to be reloaded in the same environment than for the training. The machine needs to know what were the elements used when training, from loss function to metrics and of course the dataloading. All the functions developed for the training will be integrated in python files to avoid an overload of the jupyter.
%% Cell type:code id:b2aa1692 tags:
``` python
from utils.dataloader import createDLS, PProcess
def getLab(path): # function to retrieve the target based on the input (will differ based on how you created the dataset)
path = str(path).replace('CAMS', 'S5P')
return Path(path)
```
%% Cell type:code id:3ae12012 tags:
``` python
from utils.loss import VGG, gram_matrix, FeatureLoss
base_loss = F.l1_loss
VGG1 = VGG(1) # import a vgg model for 1 channel size
vgg_m = VGG1.cuda().eval()
requires_grad(vgg_m)
blocks = [i-1 for i,o in enumerate(vgg_m.children()) if isinstance(o,nn.MaxPool2d)]
blocks, [vgg_m[i] for i in blocks] # layers numbers before the max pooling layer
loss = FeatureLoss(vgg_m, blocks[2:5], [5,12,2])
#creating the loss function that will be set up when loading the opt state of the model
```
%% Cell type:code id:removable-teach tags:
``` python
from utils.metrics import PSNR, ssim
```
%% Cell type:code id:accessible-elder tags:
``` python
def standardization(array, mean, std) : # array / mean / standard deviation
return (array-mean)/std # compute the standardized image
def normalization(array, minimum, maximum): # array to normalized / minimum and maximum to consider for the normalization (what range)
return (array - minimum)/(maximum - minimum) # compute the normalized image
def compareList(l1, l2): # function to return boolean list to compare two lists
return [i==j for i, j in zip(l1, l2)]
```
%% Cell type:markdown id:7e0b881e tags:
Setting up the pipeline for preparing and analyzing the data to produce the output
%% Cell type:markdown id:bebe08c9 tags:
# -Initialization-
%% Cell type:markdown id:0d648078 tags:
The inference function created is reapplying all the steps of preprocessing :
- Normalization
- Tilling
And then proceed to the prediction over all tiles.
After the prediction to avoid edges effects on the tiles, I set up a padding over the results to clear those effects and keep only the element that really converged in the middle.
Afterward the function retrieve all the predicted tiles and merge it into one single image based on the geographic references of the input.
%% Cell type:code id:renewable-timer tags:
``` python
from utils.inference import SISR
```
%% Cell type:markdown id:c56845e7 tags:
Retrive the range of values of your input and target to apply the same normalization for the inference.
%% Cell type:code id:raising-cowboy tags:
``` python
inpMIN = #input minimum
inpMAX = #input maximum
tarMIN = #target minimum
tarMAX = #target maximum
```
%% Cell type:code id:0fa95d29 tags:
``` python
img = '' # location of the image to apply the inference
dst = '' # location to save the result obtained during the inference
model = '' # location of the model trained (the .pkl file)
scale = # define the scaling factor, the scaling factor will help to retrieve the geographic informations of the input
tile_size = # size of the input patches here (for tiling the input) then with the scaling factor it will get the target size it wants
learn = load_learner(model, cpu=True) # load the model to get ready to run (here you can decide if you want to run it on cpu)
```
%% Cell type:code id:be796277 tags:
``` python
result = SISR(img, dst, learn, tile_size, scale, inpMIN, inpMAX, tarMIN, tarMAX)
```
%% Cell type:raw id:0d4b6467 tags:
SISR(image location to perform the prediction,
destination location of the result,
learner where the model is loaded,
size of input tiles,
scaling factor,
input minimum,
input maximum,
target minimum,
target maximum)
%% Cell type:markdown id:0c8f5d0a tags:
# -Vizualization-
%% Cell type:markdown id:2471be2d tags:
Plotting the input image
%% Cell type:code id:c4f19fcf tags:
``` python
data = R.asArray(img)
norm = normalization(data, inpMIN, inpMAX)
print(data.shape)
plt.figure(figsize=(20,20))
plt.imshow(norm)
```
%% Cell type:markdown id:4f28479f tags:
Plotting the bicubic baseline image
%% Cell type:code id:iraqi-guard tags:
``` python
dstB = '{}'.format(dst.replace('test', 'bicubic{}'.format(scale)))
geoInfo = R.getGeoInfo(img)
R.resizing(img, dstB, width=int(geoInfo['Width']*scale), height=int(geoInfo['Height']*scale), resampleAlg = 3)
plt.figure(figsize=(20,20))
plt.imshow(R.asArray(dstB))
print(R.asArray(dstB).shape)
```
%% Cell type:markdown id:7ce015ee tags:
Plotting the result image
%% Cell type:code id:be32bf9b tags:
``` python
print(result.shape)
plt.figure(figsize=(20,20))
plt.imshow(result)
```
UGeop @ 18e32753
Subproject commit 18e3275362e23cef1cd99adb3d6160fc538f9916
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment