Pottok for image color adaptation with labels - OptimalTransportGridSearch

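Using the Sinkhorn L1l2 transport (ot.da.SinkhornL1l2Transport): entropic optimal transport with an l1l2 class regularization based on the source labels.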

import numpy as np
import matplotlib.pylab as pl
import ot
import pottok

Load the pottok images

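# Load pixel samples (X) and labels (y) for the source and target images
Xs,ys,Xt,yt = pottok.datasets.load_pottoks()

# Scale pixel values from [0, 255] to [0, 1]
Xs = Xs/255
Xt = Xt/255

# Load the full images as arrays, scaled the same way
brown_pottok,black_pottok = pottok.datasets.load_pottoks(return_X_y=False)
brown_pottok = brown_pottok/255
black_pottok = black_pottok/255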
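The transport works on pixels rather than on images: predict_transfer receives brown_pottok.reshape(-1,3), an (n_pixels, 3) matrix, and the result is folded back with reshape(*brown_pottok.shape). A minimal NumPy sketch of that round trip on a small hypothetical image (the random image stands in for a real pottok photo):

```python
import numpy as np

# Hypothetical 4x5 RGB image with integer values in [0, 255]
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(4, 5, 3)).astype(float)

# Scale to [0, 1], as done for the pottok images above
image = image / 255

# Flatten to an (n_pixels, 3) matrix: one row per pixel, one column per channel
pixels = image.reshape(-1, 3)
print(pixels.shape)  # (20, 3)

# Any per-pixel transformation that keeps this shape can be folded back
restored = pixels.reshape(*image.shape)
print(np.array_equal(restored, image))  # True
```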

Optimal transport with SinkhornL1l2 using a circular grid search

gridsearch_transport_circular = pottok.OptimalTransportGridSearch(transport_function=ot.da.SinkhornL1l2Transport,
                                        params=dict(reg_e=[1e-1,1e-0], reg_cl=[1e-1,1e-0]))
gridsearch_transport_circular.preprocessing(Xs=Xs,ys=ys,Xt=Xt,yt=yt,scaler=False)
gridsearch_transport_circular.fit_circular()


# Best grid is {'reg_e': 1.0, 'reg_cl': 0.1}

Out:

Learning Optimal Transport with SinkhornL1l2Transport algorithm.
Xs and Xt are not scaled
[[0.91372549 0.87058824 0.68235294]
 [0.92156863 0.8745098  0.68627451]
 [0.92156863 0.87843137 0.69019608]
 ...
 [0.74901961 0.74509804 0.42352941]
 [0.74901961 0.74509804 0.44313725]
 [0.75686275 0.75294118 0.43529412]]
mean_squared_error is : 0.026268089908466827
mean_squared_error is : 0.02005365344925465
mean_squared_error is : 0.02499567991893459
mean_squared_error is : 0.02009413283983057
Best grid is {'reg_e': 1.0, 'reg_cl': 0.1}
Best score is 0.02005365344925465

<ot.da.SinkhornL1l2Transport object at 0x7fe995c6cb10>
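In the circular search, each parameter set is scored by a mean squared error and the lowest score wins (Best score is 0.02005365344925465). The sketch below reproduces only that selection step in plain Python, using the four printed scores. It assumes they were printed in the same parameter order as the crossed search further down (reg_e varying fastest), which is consistent with the reported best grid; how Pottok computes the error itself is not shown here.

```python
from itertools import product

params = dict(reg_e=[1e-1, 1e-0], reg_cl=[1e-1, 1e-0])

# Enumerate every combination; reg_e varies fastest (assumed order)
grid = [dict(reg_e=e, reg_cl=c) for c, e in product(params["reg_cl"], params["reg_e"])]

# mean_squared_error values copied from the output above, in printed order
scores = [0.026268089908466827, 0.02005365344925465,
          0.02499567991893459, 0.02009413283983057]

# Lower error is better, so keep the index of the minimum score
best = min(range(len(grid)), key=lambda i: scores[i])
print("Best grid is", grid[best])  # Best grid is {'reg_e': 1.0, 'reg_cl': 0.1}
print("Best score is", scores[best])
```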

Plot images

brown_pottok_transp_circular = gridsearch_transport_circular.predict_transfer(brown_pottok.reshape(-1,3))
pl.figure(1, figsize=(10,8))

pl.subplot(2, 2, 1)
pl.imshow(brown_pottok)
pl.axis('off')
pl.title('Brown pottok (Source)')

pl.subplot(2, 2, 3)
pl.imshow(black_pottok)
pl.axis('off')
pl.title('Black pottok (Target)')

pl.subplot(2, 2, 4)
pl.imshow(brown_pottok_transp_circular.reshape(*brown_pottok.shape))
pl.axis('off')
pl.title('SinkhornL1l2 (Source to Target with labels)')

pl.show()
[Figure: Brown pottok (Source), Black pottok (Target), SinkhornL1l2 (Source to Target with labels)]

Optimal transport with SinkhornL1l2 using a crossed grid search

gridsearch_transport_crossed = pottok.OptimalTransportGridSearch(transport_function=ot.da.SinkhornL1l2Transport,
                                        params=dict(reg_e=[1e-1,1e-0], reg_cl=[1e-1,1e-0]))
gridsearch_transport_crossed.preprocessing(Xs=Xs,ys=ys,Xt=Xt,yt=yt,scaler=False)
gridsearch_transport_crossed.fit_crossed()

# Best grid is {'reg_e': 0.1, 'reg_cl': 0.1}

Out:

Learning Optimal Transport with SinkhornL1l2Transport algorithm.
Xs and Xt are not scaled
{'reg_e': 0.1, 'reg_cl': 0.1}
/home/docs/checkouts/readthedocs.org/user_builds/pottok/conda/latest/lib/python3.7/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass groups=None as keyword args. From version 0.25 passing these as positional arguments will result in an error
  FutureWarning)
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 0.9992156862745099
-------------------------------------------------
{'reg_e': 1.0, 'reg_cl': 0.1}
/home/docs/checkouts/readthedocs.org/user_builds/pottok/conda/latest/lib/python3.7/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass groups=None as keyword args. From version 0.25 passing these as positional arguments will result in an error
  FutureWarning)
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 0.9184313725490196
-------------------------------------------------
{'reg_e': 0.1, 'reg_cl': 1.0}
/home/docs/checkouts/readthedocs.org/user_builds/pottok/conda/latest/lib/python3.7/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass groups=None as keyword args. From version 0.25 passing these as positional arguments will result in an error
  FutureWarning)
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 0.9992156862745099
-------------------------------------------------
{'reg_e': 1.0, 'reg_cl': 1.0}
/home/docs/checkouts/readthedocs.org/user_builds/pottok/conda/latest/lib/python3.7/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass groups=None as keyword args. From version 0.25 passing these as positional arguments will result in an error
  FutureWarning)
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 0.923921568627451
-------------------------------------------------
Best grid is {'reg_e': 0.1, 'reg_cl': 0.1}
Best score is 0.9992156862745099
Best grid is {'reg_e': 0.1, 'reg_cl': 0.1}

<ot.da.SinkhornL1l2Transport object at 0x7fe995c6cc90>
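In the crossed search the criterion is reversed: each parameter set is scored by the overall accuracy (OA) after transport, and the highest score wins. Two sets tie at 0.9992156862745099, and the reported best is the first of them, which matches first-occurrence tie-breaking as done by np.argmax. A minimal sketch of that selection using the printed OA values; whether Pottok breaks ties exactly this way internally is an assumption.

```python
import numpy as np

# Parameter sets in the order they were printed above
grid = [{'reg_e': 0.1, 'reg_cl': 0.1}, {'reg_e': 1.0, 'reg_cl': 0.1},
        {'reg_e': 0.1, 'reg_cl': 1.0}, {'reg_e': 1.0, 'reg_cl': 1.0}]

# "OA after transport" values copied from the output above
oa = np.array([0.9992156862745099, 0.9184313725490196,
               0.9992156862745099, 0.923921568627451])

# Higher accuracy is better; np.argmax returns the first index among ties
best = int(np.argmax(oa))
print(grid[best])  # {'reg_e': 0.1, 'reg_cl': 0.1}
```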

Plot images

brown_pottok_transp_crossed = gridsearch_transport_crossed.predict_transfer(brown_pottok.reshape(-1,3))
pl.figure(2, figsize=(10,8))

pl.subplot(2, 2, 1)
pl.imshow(brown_pottok)
pl.axis('off')
pl.title('Brown pottok (Source)')

pl.subplot(2, 2, 3)
pl.imshow(black_pottok)
pl.axis('off')
pl.title('Black pottok (Target)')

pl.subplot(2, 2, 4)
pl.imshow(brown_pottok_transp_crossed.reshape(*brown_pottok.shape))
pl.axis('off')
pl.title('SinkhornL1l2 (Source to Target with labels)')

pl.show()
[Figure: Brown pottok (Source), Black pottok (Target), SinkhornL1l2 (Source to Target with labels)]

Out:

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
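This warning means some transported pixel values fall outside [0, 1], and imshow clips them before drawing. Clipping explicitly, for example np.clip(brown_pottok_transp_crossed, 0, 1) before the reshape, makes that step visible and avoids the warning. A minimal sketch on hypothetical values:

```python
import numpy as np

# Hypothetical transported pixels, some slightly outside [0, 1]
transported = np.array([[1.02, 0.50, -0.01],
                        [0.30, 0.99, 0.10]])

# Clip to the valid float range for RGB images before plotting
clipped = np.clip(transported, 0, 1)
print(clipped.min(), clipped.max())  # 0.0 1.0
```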

Comparison with POT: refit SinkhornL1l2Transport directly with each best grid and check that the transformations match
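fit learns a coupling between source and target samples, regularized by entropy (reg_e) and by a class term (reg_cl) that uses ys; transform then maps samples using that coupling. The sketch below is a simplified NumPy illustration that omits the class regularizer entirely (plain Sinkhorn), so it is not what POT computes here. It only shows a coupling G with the prescribed marginals and the barycentric mapping of the training samples; the data, the cost normalization and the iteration count are assumptions.

```python
import numpy as np

def sinkhorn(a, b, M, reg, n_iter=2000):
    """Plain entropic OT coupling between weights a and b for cost matrix M."""
    K = np.exp(-M / reg)          # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)           # rescale rows to match a
        v = b / (K.T @ u)         # rescale columns to match b
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
Xs = rng.random((30, 3))          # hypothetical source pixels
Xt = rng.random((40, 3)) * 0.5    # hypothetical darker target pixels

a = np.full(len(Xs), 1 / len(Xs))  # uniform sample weights
b = np.full(len(Xt), 1 / len(Xt))

# Squared Euclidean cost, normalized to [0, 1] (assumption) to keep exp() stable
M = ((Xs[:, None, :] - Xt[None, :, :]) ** 2).sum(-1)
M = M / M.max()

G = sinkhorn(a, b, M, reg=1e-1)

# Barycentric mapping: each source sample moves to the G-weighted mean of the targets
Xs_mapped = (G / G.sum(axis=1, keepdims=True)) @ Xt
print(Xs_mapped.shape)  # (30, 3)
```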

# Same parameters as the best circular grid
ot_mapping_linear_circular = ot.da.SinkhornL1l2Transport(
    reg_e=1.0, reg_cl=0.1, verbose=True)
ot_mapping_linear_circular.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
brown_pottok_transp_pot_circular = ot_mapping_linear_circular.transform(brown_pottok.reshape(-1, 3))

# Exact element-wise comparison of the two transported images
if np.array_equal(brown_pottok_transp_pot_circular, brown_pottok_transp_circular):
    print("POT and Pottok give same transformation - circular")
else:
    print("ERROR : POT and Pottok do not give same transformation - circular")

# Same parameters as the best crossed grid
ot_mapping_linear_crossed = ot.da.SinkhornL1l2Transport(
    reg_e=0.1, reg_cl=0.1, verbose=True)
ot_mapping_linear_crossed.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
brown_pottok_transp_pot_crossed = ot_mapping_linear_crossed.transform(brown_pottok.reshape(-1, 3))

if np.array_equal(brown_pottok_transp_pot_crossed, brown_pottok_transp_crossed):
    print("POT and Pottok give same transformation - crossed")
else:
    print("ERROR : POT and Pottok do not give same transformation - crossed")

Out:

It.  |Loss        |Relative loss|Absolute loss
------------------------------------------------
    0|-7.604479e+00|0.000000e+00|0.000000e+00
    1|-1.255746e+01|3.944252e-01|4.952977e+00
    2|-1.260090e+01|3.447456e-03|4.344104e-02
    3|-1.260128e+01|3.065976e-05|3.863523e-04
    4|-1.260129e+01|2.685641e-07|3.384253e-06
    5|-1.260129e+01|2.305315e-09|2.904994e-08
POT and Pottok give same transformation - circular
It.  |Loss        |Relative loss|Absolute loss
------------------------------------------------
    0|5.772704e+00|0.000000e+00|0.000000e+00
    1|-1.198499e+00|5.816611e+00|6.971204e+00
    2|-1.266761e+00|5.388650e-02|6.826130e-02
    3|-1.267430e+00|5.283162e-04|6.696039e-04
    4|-1.267437e+00|5.180816e-06|6.566356e-06
    5|-1.267437e+00|5.078521e-08|6.436705e-08
    6|-1.267437e+00|4.976234e-10|6.307062e-10
POT and Pottok give same transformation - crossed
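The verbose log prints, for each outer iteration, the objective value (Loss) and how much it changed. Reading the numbers, Absolute loss is |f_k - f_(k-1)| and Relative loss is that change divided by |f_k|; in both runs the last row is the first whose relative change drops below 1e-8. These relations are inferred from the logged values, not taken from POT's source. A short check on the circular run (the loss column is rounded to 7 significant digits, so only the first differences are reproduced accurately):

```python
import numpy as np

# Loss column from the circular run above (rounded as printed)
loss = np.array([-7.604479, -12.55746, -12.60090, -12.60128, -12.60129, -12.60129])

# Absolute loss: |f_k - f_(k-1)|; relative loss: absolute loss / |f_k|
abs_loss = np.abs(np.diff(loss))
rel_loss = abs_loss / np.abs(loss[1:])

# Compare with the logged 4.952977e+00 and 3.944252e-01
print(abs_loss[0], rel_loss[0])
```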

Total running time of the script: (5 minutes 18.089 seconds)
