Pottok for image color adaptation with labels: RasterOptimalTransport
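This example uses the SinkhornL1l2 transport from POT (optimal transport with entropic and class-based group-lasso regularization), which takes the source labels into account.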
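For reference, POT documents SinkhornL1l2Transport as solving a transport problem with an entropic term and a group-lasso term over source classes. Roughly, in our own notation (check the POT documentation for the exact form):

$$
\gamma^\star = \arg\min_{\gamma \in \Pi(a,b)} \; \langle \gamma, M \rangle_F
  + \lambda_e \sum_{i,j} \gamma_{ij} \log \gamma_{ij}
  + \eta \sum_{j} \sum_{c} \lVert \gamma_{I_c, j} \rVert_2
$$

Here $\Pi(a,b) = \{\gamma \ge 0 : \gamma \mathbf{1} = a,\ \gamma^\top \mathbf{1} = b\}$ is the set of couplings between source weights $a$ and target weights $b$, $M_{ij}$ is the cost between source pixel $i$ and target pixel $j$, and $I_c$ is the set of source pixels with class $c$. In this example `reg_e` plays the role of $\lambda_e$ and `reg_cl` that of $\eta$; the class term is what makes the adaptation label-aware.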
import numpy as np
import matplotlib.pylab as pl
import ot
import pottok
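from sklearn.preprocessing import StandardScaler, MinMaxScaler  # standardize (center-scale) or min-max scale features
# file paths to the source/target rasters and to their labelled vector layers
source_image, source_vector, target_image, target_vector = pottok.datasets.load_pottoks(return_only_path=True)
# the same two images as arrays, used only for display
brown_pottok, black_pottok = pottok.datasets.load_pottoks(return_X_y=False)
# rescale 8-bit values to [0, 1] for imshow
brown_pottok = brown_pottok / 255
black_pottok = black_pottok / 255
label = 'level'  # name of the label field in the vector files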
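The plotting code later calls `.reshape(*brown_pottok.shape)` on the transported pixels, which suggests that the image is handled as a pixel-by-band matrix and folded back into rows × columns × bands for display. A minimal numpy sketch of that round trip, assuming row-major pixel order; the tiny image is hypothetical, not the pottok data:

```python
import numpy as np

# Hypothetical 2 x 3 RGB image with 8-bit values (stands in for a pottok raster)
img = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)

img01 = img / 255                            # floats in [0, 1], the range imshow expects
pixels = img01.reshape(-1, img01.shape[-1])  # one row per pixel, one column per band: (6, 3)
restored = pixels.reshape(*img01.shape)      # same call pattern as .reshape(*brown_pottok.shape)
```

Because `reshape` keeps the element order, every pixel returns to its original position.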
Optimal transport with SinkhornL1l2 and circular grid search
raster_transport_circular = pottok.RasterOptimalTransport(
    transport_function=ot.da.SinkhornL1l2Transport,
    # candidate values: reg_e is the entropic regularization,
    # reg_cl the class-based (group-lasso) regularization
    params=dict(reg_e=[1e-1, 1.0], reg_cl=[1e-1]))
# scale pixels with MinMaxScaler and attach the labels stored in the vector files
raster_transport_circular.preprocessing(image_source=source_image,
                                        image_target=target_image,
                                        vector_source=source_vector,
                                        vector_target=target_vector,
                                        label_source=label,
                                        label_target=label,
                                        scaler=MinMaxScaler)
raster_transport_circular.fit_circular()
# Best grid is {'reg_e': 1.0, 'reg_cl': 0.1} (lowest mean squared error)
Out:
source and target are scaled
Image is scaled
[[0.91372549 0.87058824 0.68235294]
[0.92156863 0.8745098 0.68627451]
[0.92156863 0.87843137 0.69019608]
...
[0.74901961 0.74509804 0.42352941]
[0.74901961 0.74509804 0.44313725]
[0.75686275 0.75294118 0.43529412]]
mean_squared_error is : 0.026268089908466855
mean_squared_error is : 0.02005365344925465
Best grid is {'reg_e': 1.0, 'reg_cl': 0.1}
Best score is 0.02005365344925465
<ot.da.SinkhornL1l2Transport object at 0x7fe99523eed0>
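The log reports one mean squared error per grid point and keeps the lowest. One common reading of "circular" (reverse) validation is: transport the source onto the target, transport the result back onto the source, and measure how far the round trip lands from the original. The numpy sketch below follows that reading under loud assumptions: it swaps in plain entropic Sinkhorn (no class term, so only `reg_e` is searched), uses a barycentric mapping, and runs on synthetic data. It is not pottok's implementation.

```python
import numpy as np


def sq_dist(X, Y):
    """Pairwise squared Euclidean distances between rows of X and rows of Y."""
    return ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)


def sinkhorn(a, b, M, reg, n_iter=500):
    """Entropic OT coupling via Sinkhorn-Knopp scaling (no class regularization)."""
    K = np.exp(-M / reg)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]


def barycentric_map(G, Y):
    """Send each row-sample to the coupling-weighted mean of the samples in Y."""
    return (G / G.sum(axis=1, keepdims=True)) @ Y


def circular_mse(Xs, Xt, reg):
    """Transport Xs onto Xt, transport the result back onto Xs, return the MSE."""
    a = np.full(len(Xs), 1.0 / len(Xs))  # uniform weights on source pixels
    b = np.full(len(Xt), 1.0 / len(Xt))  # uniform weights on target pixels
    G = sinkhorn(a, b, sq_dist(Xs, Xt), reg)
    Xs_to_t = barycentric_map(G, Xt)  # source -> target
    G_back = sinkhorn(a, a, sq_dist(Xs_to_t, Xs), reg)
    Xs_back = barycentric_map(G_back, Xs)  # back onto the source
    return float(np.mean((Xs_back - Xs) ** 2))


# Synthetic stand-ins for scaled pixel values (not the pottok data)
rng = np.random.default_rng(0)
Xs = rng.random((60, 3))
Xt = 0.5 * rng.random((80, 3))  # a darker hypothetical target

grid = [0.1, 1.0]  # candidate reg_e values, echoing the example's grid
scores = {reg: circular_mse(Xs, Xt, reg) for reg in grid}
best_reg = min(scores, key=scores.get)  # circular: lower MSE is better
```

With POT installed, `ot.sinkhorn` provides the coupling step directly; the hand-written loop only keeps the sketch dependency-free.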
Plot images
# transport the source pixels towards the target colors (unscaled and scaled versions)
Xt_transp_unscaled, Xt_transp_scaled = raster_transport_circular.predict_transfer(raster_transport_circular.source)
pl.figure(1, figsize=(10,8))
pl.subplot(2, 2, 1)
pl.imshow(brown_pottok)
pl.axis('off')
pl.title('Brown pottok (Source)')
pl.subplot(2, 2, 3)
pl.imshow(black_pottok)
pl.axis('off')
pl.title('Black pottok (Target)')
pl.subplot(2, 2, 4)
pl.imshow(Xt_transp_unscaled.reshape(*brown_pottok.shape)/255)
pl.axis('off')
pl.title('SinkhornL1l2 (Source to Target with labels)')
pl.show()
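`predict_transfer` returns two arrays, `Xt_transp_unscaled` and `Xt_transp_scaled`. Assuming the `MinMaxScaler` passed to `preprocessing` is fitted per band on pixel values, the scaled array lives in the scaler's [0, 1] space and the unscaled one is its inverse transform in the original 8-bit range, which would explain the `/255` in the plot above. A numpy sketch of per-band min-max scaling and its inverse (not pottok's code; the pixel values are hypothetical):

```python
import numpy as np


def minmax_fit(X):
    """Per-band minimum and range, as a min-max scaler would learn them."""
    lo = X.min(axis=0)
    span = X.max(axis=0) - lo
    return lo, np.where(span > 0, span, 1.0)  # avoid dividing by zero on flat bands


def minmax_transform(X, lo, span):
    return (X - lo) / span  # each band mapped to [0, 1]


def minmax_inverse(X_scaled, lo, span):
    return X_scaled * span + lo  # back to the original (e.g. 0-255) range


# Hypothetical 8-bit pixels: one row per pixel, one column per band
pixels = np.array([[10, 200, 30], [250, 0, 90], [130, 100, 60]], dtype=float)
lo, span = minmax_fit(pixels)
scaled = minmax_transform(pixels, lo, span)
unscaled = minmax_inverse(scaled, lo, span)
```

scikit-learn's `MinMaxScaler` offers the same steps through `fit`, `transform` and `inverse_transform`.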
Optimal transport with SinkhornL1l2 and crossed grid search
raster_transport_crossed = pottok.RasterOptimalTransport(
    transport_function=ot.da.SinkhornL1l2Transport,
    params=dict(reg_e=[1e-1, 1.0], reg_cl=[1e-1]))
raster_transport_crossed.preprocessing(image_source=source_image,
                                       image_target=target_image,
                                       vector_source=source_vector,
                                       vector_target=target_vector,
                                       label_source=label,
                                       label_target=label,
                                       scaler=MinMaxScaler)
raster_transport_crossed.fit_crossed()
# Best grid is {'reg_e': 0.1, 'reg_cl': 0.1}
Out:
source and target are scaled
Image is scaled
{'reg_e': 0.1, 'reg_cl': 0.1}
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 0.9992156862745099
-------------------------------------------------
{'reg_e': 1.0, 'reg_cl': 0.1}
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 0.9184313725490196
-------------------------------------------------
Best grid is {'reg_e': 0.1, 'reg_cl': 0.1}
Best score is 0.9992156862745099
Best grid is {'reg_e': 0.1, 'reg_cl': 0.1}
<ot.da.SinkhornL1l2Transport object at 0x7fe9951dd850>
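Here the score is an overall accuracy (OA) instead of an error, so the best grid point is the one with the highest value. The log suggests that a random forest is tuned by cross-validation and then scored on transported samples; one plausible reading, which we have not checked against pottok's source, is "fit on labelled target pixels, then predict the transported source pixels and compare with their source labels". The sketch below swaps the random forest for a hypothetical nearest-centroid classifier on toy data, then applies the selection rule to the OA values printed in the log above.

```python
import numpy as np


def nearest_centroid_predict(X_train, y_train, X_test):
    """Hypothetical stand-in classifier: assign each sample to the closest class mean."""
    classes = np.unique(y_train)
    centroids = np.stack([X_train[y_train == c].mean(axis=0) for c in classes])
    d = ((X_test[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return classes[d.argmin(axis=1)]


def overall_accuracy(y_true, y_pred):
    """Overall accuracy (OA): fraction of correctly labelled samples."""
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))


# Toy data: labelled target pixels for training, transported source pixels for scoring
X_target = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
y_target = np.array([0, 0, 1, 1])
X_src_transported = np.array([[1.0, 0.0], [9.0, 10.0]])
y_src = np.array([0, 1])
oa = overall_accuracy(y_src, nearest_centroid_predict(X_target, y_target, X_src_transported))

# Selection step, using the OA values printed in the log above; keys are (reg_e, reg_cl)
oa_by_grid = {(0.1, 0.1): 0.9992156862745099,
              (1.0, 0.1): 0.9184313725490196}
best_grid = max(oa_by_grid, key=oa_by_grid.get)  # crossed: higher OA is better
```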
Plot images
Xt_transp_unscaled, Xt_transp_scaled = raster_transport_crossed.predict_transfer(raster_transport_crossed.source)
pl.figure(2, figsize=(10,8))
pl.subplot(2, 2, 1)
pl.imshow(brown_pottok)
pl.axis('off')
pl.title('Brown pottok (Source)')
pl.subplot(2, 2, 3)
pl.imshow(black_pottok)
pl.axis('off')
pl.title('Black pottok (Target)')
pl.subplot(2, 2, 4)
# scaled values can fall slightly outside [0, 1]; imshow clips them (see the warning below)
pl.imshow(Xt_transp_scaled.reshape(*brown_pottok.shape))
pl.axis('off')
pl.title('SinkhornL1l2 (Source to Target with labels)')
pl.show()
Out:
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
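This warning means some values handed to `imshow` lie outside [0, 1]; matplotlib clips them itself before drawing. Clipping explicitly before reshaping makes that choice visible and silences the warning. A small sketch of that suggestion (not part of the original example; the transported values are hypothetical):

```python
import numpy as np

# Hypothetical transported pixels in scaled space; the mapping can overshoot [0, 1]
transported = np.array([[1.02, 0.50, -0.01],
                        [0.30, 0.99, 0.40]])
image_shape = (1, 2, 3)  # rows, columns, bands of the hypothetical image

rgb = np.clip(transported, 0.0, 1.0).reshape(*image_shape)
# pl.imshow(rgb) would now draw without the clipping warning
```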
Total running time of the script: ( 2 minutes 14.468 seconds)