OT mapping estimation for domain adaptation

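This example shows how to use MappingTransport to jointly estimate the transport coupling and approximate the transport map with either a linear or a kernelized mapping, as introduced in [8]. The hyperparameters are selected with pottok's OptimalTransportGridSearch, using either circular or crossed validation.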

[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, “Mapping estimation for discrete optimal transport”, Neural Information Processing Systems (NIPS), 2016.
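The method of [8] alternates between updating the coupling and updating the map. The following is a minimal NumPy sketch of the map-update step only. For a fixed coupling `G` (uniform source weights assumed), the linear map is a ridge regression of the source samples onto their barycentric targets `ns * G @ Xt`, shrunk toward the identity by `eta`. The kernelized variant is a Gaussian-kernel ridge regression. The function names, the unpenalized bias, and the kernel regularizer `eta * trace(L^T K L)` are assumptions of this sketch, not the POT or pottok implementation, and the coupling update is omitted.

```python
import numpy as np


def linear_map_step(Xs, Xt, G, eta=1.0):
    """Map update for a fixed coupling G (uniform source weights assumed).

    Solves min_L ||Xs1 @ L - ns * G @ Xt||^2 + eta * ||L[:d] - I||^2, where
    Xs1 is Xs with a column of ones appended, so the last row of L is a bias
    that is not penalized.
    """
    ns, d = Xs.shape
    Xs1 = np.hstack([Xs, np.ones((ns, 1))])        # append a bias column
    Y = ns * G @ Xt                                # barycentric targets
    R = np.diag(np.r_[np.ones(d), 0.0])            # penalty mask: skip bias row
    I0 = np.vstack([np.eye(d), np.zeros((1, d))])  # identity map, zero bias
    # normal equations: (Xs1^T Xs1 + eta R) L = Xs1^T Y + eta I0
    return np.linalg.solve(Xs1.T @ Xs1 + eta * R, Xs1.T @ Y + eta * I0)


def gaussian_kernel(X, Z, sigma=1.0):
    """Matrix of exp(-||x - z||^2 / (2 sigma^2)) over all rows x of X, z of Z."""
    sq_dists = ((X[:, None, :] - Z[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2.0 * sigma ** 2))


def kernel_map_step(Xs, Xt, G, eta=1.0, sigma=1.0):
    """Kernelized map update (no bias): f(x) = gaussian_kernel(x, Xs) @ L.

    With the assumed regularizer eta * trace(L^T K L), the minimizer of
    ||K @ L - ns * G @ Xt||^2 + eta * trace(L^T K L) is L = (K + eta I)^{-1} Y.
    """
    ns = Xs.shape[0]
    K = gaussian_kernel(Xs, Xs, sigma)
    Y = ns * G @ Xt
    return np.linalg.solve(K + eta * np.eye(ns), Y)


# Usage on toy data: a known affine map and a hypothetical identity coupling
rng = np.random.default_rng(0)
Xs = rng.normal(size=(50, 2))
A_true = np.array([[1.2, 0.3], [-0.2, 0.9]])
b_true = np.array([4.0, 4.0])
Xt = Xs @ A_true + b_true
G = np.eye(50) / 50  # source sample i sends all its mass to target sample i

L = linear_map_step(Xs, Xt, G, eta=1e-6)
Xs_new = rng.normal(size=(5, 2))  # samples not used to fit L
Xs_new_mapped = np.hstack([Xs_new, np.ones((5, 1))]) @ L

# Kernel variant on a small grid with a nonlinear target
grid = np.array([[i, j] for i in range(4) for j in range(4)], dtype=float)
Lk = kernel_map_step(grid, grid ** 2, np.eye(16) / 16, eta=1e-8, sigma=0.5)
```

Because the estimated map is an explicit function, it can also transport samples that were not used to fit it, such as `Xs_new` in this example.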

# Authors: Remi Flamary <remi.flamary@unice.fr>
#          Stanislas Chambon <stan.chambon@gmail.com>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 2

import numpy as np
import matplotlib.pyplot as pl
import ot
import pottok

Generate data

n_source_samples = 100
n_target_samples = 100
theta = 2 * np.pi / 20
noise_level = 0.1

Xs, ys = ot.datasets.make_data_classif(
    'gaussrot', n_source_samples, nz=noise_level)
Xs_new, _ = ot.datasets.make_data_classif(
    'gaussrot', n_source_samples, nz=noise_level)
Xt, yt = ot.datasets.make_data_classif(
    'gaussrot', n_target_samples, theta=theta, nz=noise_level)

# one of the target modes changes its variance (no linear mapping)
Xt[yt == 2] *= 3
Xt = Xt + 4
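The claim above, that no linear mapping exists between source and target once one mode's variance changes, can be checked numerically. The sketch below builds a hypothetical two-mode source with plain NumPy (not `ot.datasets.make_data_classif`). It applies the same edit to a paired copy: scale mode 2 by 3, then shift by 4, without the rotation. It then compares least-squares affine fits. Each mode alone is fit exactly, but a single affine map cannot fit both modes.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for the source data: two Gaussian modes labelled 1 and 2
Xs = np.vstack([rng.normal([-1.0, 1.0], 0.3, size=(50, 2)),
                rng.normal([1.0, -1.0], 0.3, size=(50, 2))])
ys = np.repeat([1, 2], 50)

# Same edit as in the example, applied to a paired copy of the source
Xt = Xs.copy()
Xt[ys == 2] *= 3
Xt = Xt + 4


def affine_residual(X, Y):
    """Frobenius norm of the residual of the best single affine map X -> Y."""
    X1 = np.hstack([X, np.ones((len(X), 1))])
    coef, *_ = np.linalg.lstsq(X1, Y, rcond=None)
    return np.linalg.norm(X1 @ coef - Y)


res_all = affine_residual(Xs, Xt)                      # both modes at once
res_mode2 = affine_residual(Xs[ys == 2], Xt[ys == 2])  # mode 2 alone
print(res_all, res_mode2)
```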

Plot data

pl.figure(1, (10, 5))
pl.clf()
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
pl.legend(loc=0)
pl.title('Source and target distributions')
[Figure: Source and target distributions]

Out:

Text(0.5, 1.0, 'Source and target distributions')
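The first grid search in the next section uses `ot.da.LinearTransport`. POT documents this class as the closed-form OT map between Gaussian approximations of the source and target, with `reg` added to the covariance diagonals: A = Cs^{-1/2} (Cs^{1/2} Ct Cs^{1/2})^{1/2} Cs^{-1/2} and b = mt - ms A. A minimal NumPy sketch of that formula follows. The names `sym_sqrt` and `gaussian_linear_map` are hypothetical, and details such as sample weighting and the covariance estimator may differ from POT's code.

```python
import numpy as np


def sym_sqrt(C):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    w, V = np.linalg.eigh(C)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T


def gaussian_linear_map(Xs, Xt, reg=1e-8):
    """Closed-form OT map x -> x @ A + b between Gaussian fits of Xs and Xt.

    `reg` is added to the diagonal of both covariances (assumed behaviour).
    """
    d = Xs.shape[1]
    ms, mt = Xs.mean(axis=0), Xt.mean(axis=0)
    Cs = np.cov(Xs, rowvar=False) + reg * np.eye(d)
    Ct = np.cov(Xt, rowvar=False) + reg * np.eye(d)
    Cs_half = sym_sqrt(Cs)
    Cs_inv_half = np.linalg.inv(Cs_half)
    # A is symmetric positive definite and satisfies A @ Cs @ A = Ct
    A = Cs_inv_half @ sym_sqrt(Cs_half @ Ct @ Cs_half) @ Cs_inv_half
    b = mt - ms @ A
    return A, b


rng = np.random.default_rng(0)
Xs = rng.normal(size=(200, 2))
Xt = rng.normal(size=(200, 2)) @ np.array([[2.0, 0.5], [0.0, 1.0]]) + 4.0
A, b = gaussian_linear_map(Xs, Xt, reg=1e-8)
Xs_mapped = Xs @ A + b
```

The transported samples match the target mean exactly and the target covariance up to `reg`. A single affine map of this kind cannot, however, reproduce the per-mode variance change introduced in the data.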

Instantiate the different transport algorithms and fit them

LinearTransport and MappingTransport with pottok

# LinearTransport with circular validation
pottok_circular = pottok.OptimalTransportGridSearch(
    transport_function=ot.da.LinearTransport,
    params=dict(reg=[0.1, 1.0]))
pottok_circular.preprocessing(Xs=Xs, ys=ys, Xt=Xt, yt=yt)

pottok_circular.fit_circular()

transp_Xs_linear_circular = pottok_circular.predict_transfer(Xs)

transp_Xs_linear_new_circular = pottok_circular.predict_transfer(Xs_new)


# MappingTransport with crossed validation

# lists are searched over; scalar values are kept fixed for every grid point
pottok_crossed = pottok.OptimalTransportGridSearch(
    transport_function=ot.da.MappingTransport,
    params=dict(mu=[2.0, 0.1], eta=[10.0, 2.0], kernel=["gaussian", "linear"],
                bias=True, max_iter=20, verbose=True))

pottok_crossed.preprocessing(Xs=Xs, ys=ys, Xt=Xt, yt=yt)

pottok_crossed.fit_crossed()

transp_Xs_linear_crossed = pottok_crossed.predict_transfer(Xs)

transp_Xs_linear_new_crossed = pottok_crossed.predict_transfer(Xs_new)

Out:

Learning Optimal Transport with LinearTransport algorithm.
Xs and Xt are not scaled
[[-0.94866621  1.46257368]
 [-1.15585133  0.80387928]
 [-1.44897486  0.50824256]
 [-0.89772294  0.64410193]
 [-0.85301455  0.98022324]
 [-0.94418981  0.26755038]
 [-1.61582878  1.00357011]
 [-1.32245551  0.82225201]
 [-1.13578507  1.03995629]
 [-0.48931663  0.81860919]
 [-0.88322718  0.6809819 ]
 [-2.16424689  0.99630282]
 [-1.36743175  0.16455512]
 [-1.0266606   1.01545117]
 [-1.02886639  0.91981715]
 [-1.13373572  1.24922293]
 [-0.79703491  0.57051263]
 [-1.13083705  1.41441421]
 [-1.21105601  1.37369702]
 [-1.16320176  1.38249079]
 [-1.48018003  0.8709035 ]
 [-1.25407661  1.59615867]
 [-1.45269511  0.62930655]
 [-0.97257153  1.11874869]
 [-0.36471387  0.93041457]
 [-0.90462548  1.00691671]
 [-1.48609095  1.21776776]
 [-0.50258785  0.85705094]
 [-1.20782826  1.24935776]
 [-1.00572603  1.21309359]
 [-1.28177317  1.27653567]
 [-0.84668241  0.6971633 ]
 [-1.10717338  1.20095726]
 [-1.24310182  1.24046974]
 [-1.50121559  1.36751141]
 [-0.76427572  0.39769794]
 [-1.29018458  0.93225194]
 [-1.04320324  1.74127417]
 [-1.56752668  0.9866743 ]
 [-0.85400661  1.0556128 ]
 [-1.35462786  0.5661668 ]
 [-1.23504346  1.41783799]
 [-1.03546326  0.75833089]
 [-0.87146206  1.30826361]
 [-1.01440731  1.15925026]
 [-1.28500707  1.19893484]
 [-0.90328352  0.77242109]
 [-1.19154276  0.89239418]
 [-0.66865408  1.78295535]
 [-0.65610088  1.04591981]
 [ 0.97503786 -1.00698932]
 [ 0.62240503 -1.54461218]
 [ 1.18888923 -1.01364241]
 [ 1.29619051 -0.90396368]
 [ 1.0084295  -0.91736064]
 [ 0.76162348 -1.42534812]
 [ 0.85608299 -0.38213206]
 [ 1.21389235 -0.70826344]
 [ 1.37209172 -1.42861933]
 [ 1.21949779 -0.57213643]
 [ 0.85953013 -1.46784374]
 [ 1.32591679 -1.23822922]
 [ 1.05092709 -1.07565728]
 [ 0.69183252 -1.09637222]
 [ 0.90019676 -1.11476311]
 [ 0.53324629 -0.40448687]
 [ 1.336815   -1.05308221]
 [ 1.47933921 -0.84186139]
 [ 1.16183406 -1.22335663]
 [ 0.58599988 -0.84367597]
 [ 0.56189599 -1.70713211]
 [ 1.12602648 -1.14272708]
 [ 1.61114867 -0.94689892]
 [ 1.17484719 -0.70382369]
 [ 1.16628722 -1.12654678]
 [ 0.83733867 -0.95591057]
 [ 1.27888755 -1.23569109]
 [ 1.1958799  -1.22732689]
 [ 0.42966247 -0.57663678]
 [ 0.82087017 -1.03788072]
 [ 0.64614117 -0.42800583]
 [ 1.27089072 -1.74443049]
 [ 0.79639689 -1.06619423]
 [ 0.64627273 -1.17496065]
 [ 0.73029126 -0.5363713 ]
 [ 0.75422988 -1.25619541]
 [ 0.97897433 -1.1813627 ]
 [ 0.48730624 -0.98615382]
 [ 0.90724552 -1.38523374]
 [ 0.97754585 -0.85924557]
 [ 1.55754941 -0.70473317]
 [ 1.38506649 -1.282643  ]
 [ 1.11432937 -1.26829552]
 [ 0.54430008 -1.4652264 ]
 [ 1.74090842 -0.61422389]
 [ 1.20434462 -1.2901593 ]
 [ 0.78193011 -1.20035993]
 [ 0.95390889 -0.79133333]
 [ 1.18749163 -1.22586651]
 [ 0.57802052 -0.76958522]]
mean_squared_error is : 5.205524364387433
mean_squared_error is : 11.496210690847514
Best grid is {'reg': 0.1}
Best score is 5.205524364387433
Learning Optimal Transport with MappingTransport algorithm.
Xs and Xt are not scaled
{'mu': 2.0, 'eta': 10.0, 'kernel': 'gaussian', 'bias': True, 'max_iter': 20, 'verbose': True}
It.  |Loss        |Delta loss
--------------------------------
    0|8.955784e+03|0.000000e+00
    1|8.944348e+03|-1.277009e-03
    2|8.944173e+03|-1.948179e-05
    3|8.944167e+03|-7.459824e-07
/home/docs/checkouts/readthedocs.org/user_builds/pottok/conda/latest/lib/python3.7/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass groups=None as keyword args. From version 0.25 passing these as positional arguments will result in an error
  FutureWarning)
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 1.0
-------------------------------------------------
{'mu': 0.1, 'eta': 10.0, 'kernel': 'gaussian', 'bias': True, 'max_iter': 20, 'verbose': True}
It.  |Loss        |Delta loss
--------------------------------
    0|9.133082e+02|0.000000e+00
    1|8.120578e+02|-1.108612e-01
    2|7.886155e+02|-2.886778e-02
    3|7.759405e+02|-1.607250e-02
    4|7.680985e+02|-1.010637e-02
    5|7.631109e+02|-6.493508e-03
    6|7.599661e+02|-4.121008e-03
    7|7.579619e+02|-2.637198e-03
    8|7.566935e+02|-1.673420e-03
    9|7.558892e+02|-1.062894e-03
   10|7.553798e+02|-6.740218e-04
   11|7.550540e+02|-4.311906e-04
   12|7.548442e+02|-2.779206e-04
   13|7.547086e+02|-1.797015e-04
   14|7.546188e+02|-1.189777e-04
   15|7.545588e+02|-7.951085e-05
   16|7.545180e+02|-5.402497e-05
   17|7.544889e+02|-3.855979e-05
   18|7.544678e+02|-2.791434e-05
   19|7.544518e+02|-2.128301e-05
It.  |Loss        |Delta loss
--------------------------------
   20|7.544391e+02|-1.686329e-05
/home/docs/checkouts/readthedocs.org/user_builds/pottok/conda/latest/lib/python3.7/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass groups=None as keyword args. From version 0.25 passing these as positional arguments will result in an error
  FutureWarning)
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 0.96
-------------------------------------------------
{'mu': 2.0, 'eta': 10.0, 'kernel': 'linear', 'bias': True, 'max_iter': 20, 'verbose': True}
It.  |Loss        |Delta loss
--------------------------------
    0|8.513898e+03|0.000000e+00
    1|8.505693e+03|-9.637323e-04
    2|8.505281e+03|-4.841365e-05
    3|8.505122e+03|-1.872791e-05
    4|8.505078e+03|-5.172201e-06
/home/docs/checkouts/readthedocs.org/user_builds/pottok/conda/latest/lib/python3.7/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass groups=None as keyword args. From version 0.25 passing these as positional arguments will result in an error
  FutureWarning)
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 1.0
-------------------------------------------------
{'mu': 0.1, 'eta': 10.0, 'kernel': 'linear', 'bias': True, 'max_iter': 20, 'verbose': True}
It.  |Loss        |Delta loss
--------------------------------
    0|4.714217e+02|0.000000e+00
    1|4.393320e+02|-6.807004e-02
    2|4.382694e+02|-2.418691e-03
    3|4.381401e+02|-2.951365e-04
    4|4.380874e+02|-1.202040e-04
    5|4.380496e+02|-8.633910e-05
    6|4.380186e+02|-7.070721e-05
    7|4.379928e+02|-5.897360e-05
    8|4.379674e+02|-5.790514e-05
    9|4.379471e+02|-4.635792e-05
   10|4.379294e+02|-4.053242e-05
   11|4.379121e+02|-3.937546e-05
   12|4.378975e+02|-3.340889e-05
   13|4.378826e+02|-3.391067e-05
   14|4.378700e+02|-2.898288e-05
   15|4.378583e+02|-2.650786e-05
   16|4.378476e+02|-2.452126e-05
   17|4.378374e+02|-2.336900e-05
   18|4.378281e+02|-2.109940e-05
   19|4.378195e+02|-1.976368e-05
It.  |Loss        |Delta loss
--------------------------------
   20|4.378116e+02|-1.807155e-05
/home/docs/checkouts/readthedocs.org/user_builds/pottok/conda/latest/lib/python3.7/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass groups=None as keyword args. From version 0.25 passing these as positional arguments will result in an error
  FutureWarning)
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 0.98
-------------------------------------------------
{'mu': 2.0, 'eta': 2.0, 'kernel': 'gaussian', 'bias': True, 'max_iter': 20, 'verbose': True}
It.  |Loss        |Delta loss
--------------------------------
    0|8.603235e+03|0.000000e+00
    1|8.596731e+03|-7.560021e-04
    2|8.596566e+03|-1.921699e-05
    3|8.596525e+03|-4.719687e-06
/home/docs/checkouts/readthedocs.org/user_builds/pottok/conda/latest/lib/python3.7/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass groups=None as keyword args. From version 0.25 passing these as positional arguments will result in an error
  FutureWarning)
Crossed validation OA : 0.99
Best parameter : {'n_estimators': 100}
OA after transport 1.0
-------------------------------------------------
{'mu': 0.1, 'eta': 2.0, 'kernel': 'gaussian', 'bias': True, 'max_iter': 20, 'verbose': True}
It.  |Loss        |Delta loss
--------------------------------
    0|5.607591e+02|0.000000e+00
    1|5.307834e+02|-5.345554e-02
    2|5.249185e+02|-1.104954e-02
    3|5.226565e+02|-4.309228e-03
    4|5.217580e+02|-1.719068e-03
    5|5.213766e+02|-7.310240e-04
    6|5.212079e+02|-3.235475e-04
    7|5.211318e+02|-1.460565e-04
    8|5.210910e+02|-7.828221e-05
    9|5.210669e+02|-4.627376e-05
   10|5.210510e+02|-3.046198e-05
   11|5.210383e+02|-2.433469e-05
   12|5.210274e+02|-2.092458e-05
   13|5.210173e+02|-1.941386e-05
   14|5.210083e+02|-1.726729e-05
   15|5.209995e+02|-1.697686e-05
   16|5.209912e+02|-1.577662e-05
   17|5.209833e+02|-1.518414e-05
   18|5.209759e+02|-1.432758e-05
   19|5.209687e+02|-1.372604e-05
It.  |Loss        |Delta loss
--------------------------------
   20|5.209620e+02|-1.290901e-05
/home/docs/checkouts/readthedocs.org/user_builds/pottok/conda/latest/lib/python3.7/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass groups=None as keyword args. From version 0.25 passing these as positional arguments will result in an error
  FutureWarning)
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 0.96
-------------------------------------------------
{'mu': 2.0, 'eta': 2.0, 'kernel': 'linear', 'bias': True, 'max_iter': 20, 'verbose': True}
It.  |Loss        |Delta loss
--------------------------------
    0|8.503064e+03|0.000000e+00
    1|8.495569e+03|-8.814471e-04
    2|8.495267e+03|-3.553040e-05
    3|8.495164e+03|-1.210050e-05
    4|8.495150e+03|-1.726083e-06
/home/docs/checkouts/readthedocs.org/user_builds/pottok/conda/latest/lib/python3.7/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass groups=None as keyword args. From version 0.25 passing these as positional arguments will result in an error
  FutureWarning)
Crossed validation OA : 0.99
Best parameter : {'n_estimators': 100}
OA after transport 1.0
-------------------------------------------------
{'mu': 0.1, 'eta': 2.0, 'kernel': 'linear', 'bias': True, 'max_iter': 20, 'verbose': True}
It.  |Loss        |Delta loss
--------------------------------
    0|4.605881e+02|0.000000e+00
    1|4.342240e+02|-5.723999e-02
    2|4.325578e+02|-3.837386e-03
    3|4.321406e+02|-9.644834e-04
    4|4.319816e+02|-3.678795e-04
    5|4.318974e+02|-1.948143e-04
    6|4.318405e+02|-1.318959e-04
    7|4.317950e+02|-1.052453e-04
    8|4.317620e+02|-7.639865e-05
    9|4.317358e+02|-6.067535e-05
   10|4.317121e+02|-5.487670e-05
   11|4.316920e+02|-4.670607e-05
   12|4.316734e+02|-4.299021e-05
   13|4.316581e+02|-3.556090e-05
   14|4.316427e+02|-3.565968e-05
   15|4.316307e+02|-2.780602e-05
   16|4.316189e+02|-2.733524e-05
   17|4.316075e+02|-2.644649e-05
   18|4.315973e+02|-2.355514e-05
   19|4.315885e+02|-2.030898e-05
It.  |Loss        |Delta loss
--------------------------------
   20|4.315797e+02|-2.040972e-05
/home/docs/checkouts/readthedocs.org/user_builds/pottok/conda/latest/lib/python3.7/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass groups=None as keyword args. From version 0.25 passing these as positional arguments will result in an error
  FutureWarning)
Crossed validation OA : 1.0
Best parameter : {'n_estimators': 100}
OA after transport 0.98
-------------------------------------------------
Best grid is {'mu': 2.0, 'eta': 10.0, 'kernel': 'gaussian', 'bias': True, 'max_iter': 20, 'verbose': True}
Best score is 1.0
Best grid is {'mu': 2.0, 'eta': 10.0, 'kernel': 'gaussian', 'bias': True, 'max_iter': 20, 'verbose': True}
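The log above lists eight parameter combinations in a fixed order. Entries given as lists (`mu`, `eta`, `kernel`) are searched over, while scalars (`bias`, `max_iter`, `verbose`) are repeated unchanged. A hypothetical helper that reproduces this expansion with `itertools.product` is sketched below. Iterating over sorted keys, as scikit-learn's `ParameterGrid` does, yields the same order as the log, but pottok's own implementation may differ.

```python
from itertools import product


def expand_grid(params):
    """Hypothetical helper: expand a pottok-style ``params`` dict into a grid.

    Values given as lists are searched over; any other value is held fixed.
    Keys are iterated in sorted order, so among the varying keys the one that
    comes last alphabetically changes fastest.
    """
    keys = sorted(params)
    values = [params[k] if isinstance(params[k], list) else [params[k]]
              for k in keys]
    grid = []
    for combo in product(*values):
        chosen = dict(zip(keys, combo))
        # restore the caller's key order for readable printing
        grid.append({k: chosen[k] for k in params})
    return grid


params = dict(mu=[2.0, 0.1], eta=[10.0, 2.0], kernel=["gaussian", "linear"],
              bias=True, max_iter=20, verbose=True)
grid = expand_grid(params)
for g in grid:
    print(g)
```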

Total running time of the script: ( 0 minutes 13.771 seconds)
