OT mapping estimation for domain adaptation
This example shows how to use MappingTransport to jointly estimate the transport coupling and an approximation of the transport map, using either a linear or a kernelized mapping, as introduced in [8].
[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, “Mapping estimation for discrete optimal transport”, Neural Information Processing Systems (NIPS), 2016.
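For orientation, the linear variant of the joint problem can be written as below. This is a sketch that follows the formulation given in the POT documentation for ot.da.joint_OT_mapping_linear; it is stated here as an assumption, since this page does not define the symbols itself.

```latex
\min_{\gamma, L} \; \| L(X_s) - n_s \gamma X_t \|_F^2
  + \mu \, \langle \gamma, M \rangle_F
  + \eta \, \| L - I \|_F^2
\quad \text{s.t.} \quad
\gamma \mathbf{1} = a, \;\; \gamma^\top \mathbf{1} = b, \;\; \gamma \ge 0
```

Here \gamma is the coupling, L the linear map, M the squared Euclidean cost matrix between source and target samples, and a, b the sample weights. Under this reading, mu weights the transport cost and eta the regularization of the map, which are the two quantities varied in the grid search further down. The kernelized variant replaces the linear map with a kernel expansion.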
# Authors: Remi Flamary <remi.flamary@unice.fr>
# Stanislas Chambon <stan.chambon@gmail.com>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 2
import numpy as np
import matplotlib.pylab as pl
import ot
import pottok
Generate data
n_source_samples = 100
n_target_samples = 100
theta = 2 * np.pi / 20
noise_level = 0.1
Xs, ys = ot.datasets.make_data_classif(
    'gaussrot', n_source_samples, nz=noise_level)
Xs_new, _ = ot.datasets.make_data_classif(
    'gaussrot', n_source_samples, nz=noise_level)
Xt, yt = ot.datasets.make_data_classif(
    'gaussrot', n_target_samples, theta=theta, nz=noise_level)
# one of the target modes changes its variance (so no single linear mapping fits)
Xt[yt == 2] *= 3
Xt = Xt + 4
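The remark that scaling one target mode rules out a linear mapping can be checked numerically. The sketch below uses only NumPy and a hypothetical stand-in for the 'gaussrot' data (two Gaussian classes with made-up centers and a fixed seed), and it assumes known point-to-point correspondences, which the real example does not have. It fits the best least-squares affine map with and without the class-dependent scaling and compares the residuals.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for the 'gaussrot' data: two Gaussian classes
# (labels 1 and 2), 50 samples each, with isotropic noise.
n_per_class = 50
noise_level = 0.1
centers = np.array([[-1.0, 1.0], [1.0, -1.0]])
y = np.repeat([1, 2], n_per_class)
X = centers[y - 1] + noise_level * rng.standard_normal((2 * n_per_class, 2))

# Rotation by theta, the same angle as in the example above.
theta = 2 * np.pi / 20
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta), np.cos(theta)]])


def affine_fit_mse(X_src, X_tgt):
    """Mean squared residual of the best least-squares affine map."""
    X_aug = np.hstack([X_src, np.ones((len(X_src), 1))])  # add a bias column
    coef, *_ = np.linalg.lstsq(X_aug, X_tgt, rcond=None)
    return float(np.mean((X_aug @ coef - X_tgt) ** 2))


# Target A: rotation and shift only, which is an affine map of the source.
Xt_plain = X @ R.T + 4

# Target B: class 2 is additionally scaled by 3 before the shift.
Xt_scaled = X @ R.T
Xt_scaled[y == 2] *= 3
Xt_scaled = Xt_scaled + 4

mse_plain = affine_fit_mse(X, Xt_plain)
mse_scaled = affine_fit_mse(X, Xt_scaled)
print(f"affine fit MSE, rotation only:  {mse_plain:.2e}")
print(f"affine fit MSE, class 2 scaled: {mse_scaled:.2e}")
```

With rotation and shift alone the residual is at the level of floating-point error, while the class-dependent scaling leaves a residual that no single affine map can remove. This is the situation where a kernelized mapping may help.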
Plot data
pl.figure(1, (10, 5))
pl.clf()
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
pl.legend(loc=0)
pl.title('Source and target distributions')
Out:
Text(0.5, 1.0, 'Source and target distributions')
Instantiate the different transport algorithms and fit them
MappingTransport with pottok
# LinearTransport with circular validation
pottok_circular = pottok.OptimalTransportGridSearch(
    transport_function=ot.da.LinearTransport,
    params=dict(reg=[0.1, 1.0]))
pottok_circular.preprocessing(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
pottok_circular.fit_circular()
transp_Xs_linear_circular = pottok_circular.predict_transfer(Xs)
transp_Xs_linear_new_circular = pottok_circular.predict_transfer(Xs_new)
# MappingTransport with crossed validation
pottok_crossed = pottok.OptimalTransportGridSearch(
    transport_function=ot.da.MappingTransport,
    params=dict(mu=[2.0, 0.1], eta=[10.0, 2.0], kernel=["gaussian", "linear"],
                bias=True, max_iter=20, verbose=True))
pottok_crossed.preprocessing(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
pottok_crossed.fit_crossed()
# transport the source samples with the cross-validated model
transp_Xs_linear_crossed = pottok_crossed.predict_transfer(Xs)
transp_Xs_linear_new_crossed = pottok_crossed.predict_transfer(Xs_new)
Out:
Learning Optimal Transport with LinearTransport algorithm.
Xs and Xt are not scaled
[[-0.94866621 1.46257368]
[-1.15585133 0.80387928]
[-1.44897486 0.50824256]
[-0.89772294 0.64410193]
[-0.85301455 0.98022324]
[-0.94418981 0.26755038]
[-1.61582878 1.00357011]
[-1.32245551 0.82225201]
[-1.13578507 1.03995629]
[-0.48931663 0.81860919]
[-0.88322718 0.6809819 ]
[-2.16424689 0.99630282]
[-1.36743175 0.16455512]
[-1.0266606 1.01545117]
[-1.02886639 0.91981715]
[-1.13373572 1.24922293]
[-0.79703491 0.57051263]
[-1.13083705 1.41441421]
[-1.21105601 1.37369702]
[-1.16320176 1.38249079]
[-1.48018003 0.8709035 ]
[-1.25407661 1.59615867]
[-1.45269511 0.62930655]
[-0.97257153 1.11874869]
[-0.36471387 0.93041457]
[-0.90462548 1.00691671]
[-1.48609095 1.21776776]
[-0.50258785 0.85705094]
[-1.20782826 1.24935776]
[-1.00572603 1.21309359]
[-1.28177317 1.27653567]
[-0.84668241 0.6971633 ]
[-1.10717338 1.20095726]
[-1.24310182 1.24046974]
[-1.50121559 1.36751141]
[-0.76427572 0.39769794]
[-1.29018458 0.93225194]
[-1.04320324 1.74127417]
[-1.56752668 0.9866743 ]
[-0.85400661 1.0556128 ]
[-1.35462786 0.5661668 ]
[-1.23504346 1.41783799]
[-1.03546326 0.75833089]
[-0.87146206 1.30826361]
[-1.01440731 1.15925026]
[-1.28500707 1.19893484]
[-0.90328352 0.77242109]
[-1.19154276 0.89239418]
[-0.66865408 1.78295535]
[-0.65610088 1.04591981]
[ 0.97503786 -1.00698932]
[ 0.62240503 -1.54461218]
[ 1.18888923 -1.01364241]
[ 1.29619051 -0.90396368]
[ 1.0084295 -0.91736064]
[ 0.76162348 -1.42534812]
[ 0.85608299 -0.38213206]
[ 1.21389235 -0.70826344]
[ 1.37209172 -1.42861933]
[ 1.21949779 -0.57213643]
[ 0.85953013 -1.46784374]
[ 1.32591679 -1.23822922]
[ 1.05092709 -1.07565728]
[ 0.69183252 -1.09637222]
[ 0.90019676 -1.11476311]
[ 0.53324629 -0.40448687]
[ 1.336815 -1.05308221]
[ 1.47933921 -0.84186139]
[ 1.16183406 -1.22335663]
[ 0.58599988 -0.84367597]
[ 0.56189599 -1.70713211]
[ 1.12602648 -1.14272708]
[ 1.61114867 -0.94689892]
[ 1.17484719 -0.70382369]
[ 1.16628722 -1.12654678]
[ 0.83733867 -0.95591057]
[ 1.27888755 -1.23569109]
[ 1.1958799 -1.22732689]
[ 0.42966247 -0.57663678]
[ 0.82087017 -1.03788072]
[ 0.64614117 -0.42800583]
[ 1.27089072 -1.74443049]
[ 0.79639689 -1.06619423]
[ 0.64627273 -1.17496065]
[ 0.73029126 -0.5363713 ]
[ 0.75422988 -1.25619541]
[ 0.97897433 -1.1813627 ]
[ 0.48730624 -0.98615382]
[ 0.90724552 -1.38523374]
[ 0.97754585 -0.85924557]
[ 1.55754941 -0.70473317]
[ 1.38506649 -1.282643 ]
[ 1.11432937 -1.26829552]
[ 0.54430008 -1.4652264 ]
[ 1.74090842 -0.61422389]
[ 1.20434462 -1.2901593 ]
[ 0.78193011 -1.20035993]
[ 0.95390889 -0.79133333]
[ 1.18749163 -1.22586651]
[ 0.57802052 -0.76958522]]
mean_squared_error is : 5.205524364387433
mean_squared_error is : 11.496210690847514
Best grid is {'reg': 0.1}
Best score is 5.205524364387433
Learning Optimal Transport with MappingTransport algorithm.
Xs and Xt are not scaled
{'mu': 2.0, 'eta': 10.0, 'kernel': 'gaussian', 'bias': True, 'max_iter': 20, 'verbose': True}
It. |Loss |Delta loss
--------------------------------
0|8.955784e+03|0.000000e+00
1|8.944348e+03|-1.277009e-03
2|8.944173e+03|-1.948179e-05
3|8.944167e+03|-7.459824e-07
/home/docs/checkouts/readthedocs.org/user_builds/pottok/conda/latest/lib/python3.7/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass groups=None as keyword args. From version 0.25 passing these as positional arguments will result in an error
FutureWarning)
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 1.0
-------------------------------------------------
{'mu': 0.1, 'eta': 10.0, 'kernel': 'gaussian', 'bias': True, 'max_iter': 20, 'verbose': True}
It. |Loss |Delta loss
--------------------------------
0|9.133082e+02|0.000000e+00
1|8.120578e+02|-1.108612e-01
2|7.886155e+02|-2.886778e-02
3|7.759405e+02|-1.607250e-02
4|7.680985e+02|-1.010637e-02
5|7.631109e+02|-6.493508e-03
6|7.599661e+02|-4.121008e-03
7|7.579619e+02|-2.637198e-03
8|7.566935e+02|-1.673420e-03
9|7.558892e+02|-1.062894e-03
10|7.553798e+02|-6.740218e-04
11|7.550540e+02|-4.311906e-04
12|7.548442e+02|-2.779206e-04
13|7.547086e+02|-1.797015e-04
14|7.546188e+02|-1.189777e-04
15|7.545588e+02|-7.951085e-05
16|7.545180e+02|-5.402497e-05
17|7.544889e+02|-3.855979e-05
18|7.544678e+02|-2.791434e-05
19|7.544518e+02|-2.128301e-05
It. |Loss |Delta loss
--------------------------------
20|7.544391e+02|-1.686329e-05
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 0.96
-------------------------------------------------
{'mu': 2.0, 'eta': 10.0, 'kernel': 'linear', 'bias': True, 'max_iter': 20, 'verbose': True}
It. |Loss |Delta loss
--------------------------------
0|8.513898e+03|0.000000e+00
1|8.505693e+03|-9.637323e-04
2|8.505281e+03|-4.841365e-05
3|8.505122e+03|-1.872791e-05
4|8.505078e+03|-5.172201e-06
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 1.0
-------------------------------------------------
{'mu': 0.1, 'eta': 10.0, 'kernel': 'linear', 'bias': True, 'max_iter': 20, 'verbose': True}
It. |Loss |Delta loss
--------------------------------
0|4.714217e+02|0.000000e+00
1|4.393320e+02|-6.807004e-02
2|4.382694e+02|-2.418691e-03
3|4.381401e+02|-2.951365e-04
4|4.380874e+02|-1.202040e-04
5|4.380496e+02|-8.633910e-05
6|4.380186e+02|-7.070721e-05
7|4.379928e+02|-5.897360e-05
8|4.379674e+02|-5.790514e-05
9|4.379471e+02|-4.635792e-05
10|4.379294e+02|-4.053242e-05
11|4.379121e+02|-3.937546e-05
12|4.378975e+02|-3.340889e-05
13|4.378826e+02|-3.391067e-05
14|4.378700e+02|-2.898288e-05
15|4.378583e+02|-2.650786e-05
16|4.378476e+02|-2.452126e-05
17|4.378374e+02|-2.336900e-05
18|4.378281e+02|-2.109940e-05
19|4.378195e+02|-1.976368e-05
It. |Loss |Delta loss
--------------------------------
20|4.378116e+02|-1.807155e-05
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 0.98
-------------------------------------------------
{'mu': 2.0, 'eta': 2.0, 'kernel': 'gaussian', 'bias': True, 'max_iter': 20, 'verbose': True}
It. |Loss |Delta loss
--------------------------------
0|8.603235e+03|0.000000e+00
1|8.596731e+03|-7.560021e-04
2|8.596566e+03|-1.921699e-05
3|8.596525e+03|-4.719687e-06
Crossed validation OA : 0.99
Best parameter : {'n_estimators': 100}
OA after transport 1.0
-------------------------------------------------
{'mu': 0.1, 'eta': 2.0, 'kernel': 'gaussian', 'bias': True, 'max_iter': 20, 'verbose': True}
It. |Loss |Delta loss
--------------------------------
0|5.607591e+02|0.000000e+00
1|5.307834e+02|-5.345554e-02
2|5.249185e+02|-1.104954e-02
3|5.226565e+02|-4.309228e-03
4|5.217580e+02|-1.719068e-03
5|5.213766e+02|-7.310240e-04
6|5.212079e+02|-3.235475e-04
7|5.211318e+02|-1.460565e-04
8|5.210910e+02|-7.828221e-05
9|5.210669e+02|-4.627376e-05
10|5.210510e+02|-3.046198e-05
11|5.210383e+02|-2.433469e-05
12|5.210274e+02|-2.092458e-05
13|5.210173e+02|-1.941386e-05
14|5.210083e+02|-1.726729e-05
15|5.209995e+02|-1.697686e-05
16|5.209912e+02|-1.577662e-05
17|5.209833e+02|-1.518414e-05
18|5.209759e+02|-1.432758e-05
19|5.209687e+02|-1.372604e-05
It. |Loss |Delta loss
--------------------------------
20|5.209620e+02|-1.290901e-05
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 0.96
-------------------------------------------------
{'mu': 2.0, 'eta': 2.0, 'kernel': 'linear', 'bias': True, 'max_iter': 20, 'verbose': True}
It. |Loss |Delta loss
--------------------------------
0|8.503064e+03|0.000000e+00
1|8.495569e+03|-8.814471e-04
2|8.495267e+03|-3.553040e-05
3|8.495164e+03|-1.210050e-05
4|8.495150e+03|-1.726083e-06
Crossed validation OA : 0.99
Best parameter : {'n_estimators': 100}
OA after transport 1.0
-------------------------------------------------
{'mu': 0.1, 'eta': 2.0, 'kernel': 'linear', 'bias': True, 'max_iter': 20, 'verbose': True}
It. |Loss |Delta loss
--------------------------------
0|4.605881e+02|0.000000e+00
1|4.342240e+02|-5.723999e-02
2|4.325578e+02|-3.837386e-03
3|4.321406e+02|-9.644834e-04
4|4.319816e+02|-3.678795e-04
5|4.318974e+02|-1.948143e-04
6|4.318405e+02|-1.318959e-04
7|4.317950e+02|-1.052453e-04
8|4.317620e+02|-7.639865e-05
9|4.317358e+02|-6.067535e-05
10|4.317121e+02|-5.487670e-05
11|4.316920e+02|-4.670607e-05
12|4.316734e+02|-4.299021e-05
13|4.316581e+02|-3.556090e-05
14|4.316427e+02|-3.565968e-05
15|4.316307e+02|-2.780602e-05
16|4.316189e+02|-2.733524e-05
17|4.316075e+02|-2.644649e-05
18|4.315973e+02|-2.355514e-05
19|4.315885e+02|-2.030898e-05
It. |Loss |Delta loss
--------------------------------
20|4.315797e+02|-2.040972e-05
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 0.98
-------------------------------------------------
Best grid is {'mu': 2.0, 'eta': 10.0, 'kernel': 'gaussian', 'bias': True, 'max_iter': 20, 'verbose': True}
Best score is 1.0
Best grid is {'mu': 2.0, 'eta': 10.0, 'kernel': 'gaussian', 'bias': True, 'max_iter': 20, 'verbose': True}
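The eight parameter sets in the log above are what you get by combining every list-valued entry of params and carrying the scalar entries along unchanged. The sketch below reproduces that expansion and the final selection in plain Python. It is an illustration only: expand_grid is a hypothetical helper, not pottok's implementation; the key ordering is chosen so the result matches the order in the log; and it assumes the reported best score is the "OA after transport" value, with ties going to the first candidate.

```python
from itertools import product


def expand_grid(params):
    """Expand a dict of candidate values into a list of parameter sets.

    List values are candidates to combine; any other value is treated
    as a fixed setting shared by every combination. (Hypothetical helper.)
    """
    fixed = {k: v for k, v in params.items() if not isinstance(v, list)}
    names = sorted(k for k, v in params.items() if isinstance(v, list))
    grid = []
    for values in product(*(params[name] for name in names)):
        combo = dict(zip(names, values))
        combo.update(fixed)
        grid.append(combo)
    return grid


params = dict(mu=[2.0, 0.1], eta=[10.0, 2.0], kernel=["gaussian", "linear"],
              bias=True, max_iter=20, verbose=True)
grid = expand_grid(params)

# "OA after transport" values copied from the log above, in log order.
scores = [1.0, 0.96, 1.0, 0.98, 1.0, 0.96, 1.0, 0.98]

# max() returns the first maximal element, so ties go to the earliest set.
best_index = max(range(len(grid)), key=lambda i: scores[i])
best = grid[best_index]
print(len(grid), "parameter sets; best:", best, "score:", scores[best_index])
```

Several parameter sets reach an accuracy of 1.0 here, so on this toy data the selected grid mostly reflects evaluation order rather than a clear winner.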
Total running time of the script: ( 0 minutes 13.771 seconds)