""" Tests for gromov._semirelaxed.py """

# Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
#
# License: MIT License

import numpy as np
import pytest

import ot
from ot.backend import torch


def test_semirelaxed_gromov(nx):
    rng = np.random.RandomState(0)
    # unbalanced proportions
    list_n = [30, 15]
    nt = 2
    ns = np.sum(list_n)
    # create directed sbm with C2 as connectivity matrix
    C1 = np.zeros((ns, ns), dtype=np.float64)
    C2 = np.array([[0.8, 0.05],
                   [0.05, 1.]], dtype=np.float64)
    for i in range(nt):
        for j in range(nt):
            ni, nj = list_n[i], list_n[j]
            xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j])
            C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij
    p = ot.unif(ns, type_as=C1)
    q0 = ot.unif(C2.shape[0], type_as=C1)
    G0 = p[:, None] * q0[None, :]
    # asymmetric
    C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0)

    for loss_fun in ['square_loss', 'kl_loss']:
        G, log = ot.gromov.semirelaxed_gromov_wasserstein(
            C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0)
        Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(
            C1b, C2b, None, loss_fun='square_loss', symmetric=False, log=True,
            G0=None, alpha_min=0., alpha_max=1.)

        # check constraints
        np.testing.assert_allclose(G, Gb, atol=1e-06)
        np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04)
        np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01)
        np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01)

        srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(
            C1, C2, None, loss_fun='square_loss', symmetric=False, log=True, G0=G0)
        srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(
            C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None)

        G = log2['T']
        Gb = nx.to_numpy(logb2['T'])
        # check constraints
        np.testing.assert_allclose(G, Gb, atol=1e-06)
        np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
        np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04)  # cf convergence gromov

        np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07)
        np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07)

    # symmetric
    C1 = 0.5 * (C1 + C1.T)
    C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0)

    G, log = ot.gromov.semirelaxed_gromov_wasserstein(
        C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=None)
    Gb = ot.gromov.semirelaxed_gromov_wasserstein(
        C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=G0b)

    # check constraints
    np.testing.assert_allclose(G, Gb, atol=1e-06)
    np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04)  # cf convergence gromov
    np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02)  # cf convergence gromov

    srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(
        C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0)
    srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(
        C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None)

    srgw_ = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=False, G0=G0)

    G = log2['T']
    # check constraints
    np.testing.assert_allclose(G, Gb, atol=1e-06)
    np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04)  # cf convergence gromov
    np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01)
    np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01)

    np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07)
    np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07)
    np.testing.assert_allclose(srgw, srgw_, atol=1e-07)


def test_semirelaxed_gromov2_gradients():
    n_samples = 50  # nb samples

    mu_s = np.array([0, 0])
    cov_s = np.array([[1, 0], [0, 1]])

    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)

    xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)

    p = ot.unif(n_samples)

    C1 = ot.dist(xs, xs)
    C2 = ot.dist(xt, xt)

    C1 /= C1.max()
    C2 /= C2.max()

    if torch:

        devices = [torch.device("cpu")]
        if torch.cuda.is_available():
            devices.append(torch.device("cuda"))
        for device in devices:
            for loss_fun in ['square_loss', 'kl_loss']:
                # semirelaxed solvers do not support gradients over masses yet.
                p1 = torch.tensor(p, requires_grad=False, device=device)
                C11 = torch.tensor(C1, requires_grad=True, device=device)
                C12 = torch.tensor(C2, requires_grad=True, device=device)

                val = ot.gromov.semirelaxed_gromov_wasserstein2(C11, C12, p1, loss_fun=loss_fun)

                val.backward()

                assert val.device == p1.device
                assert p1.grad is None
                assert C11.shape == C11.grad.shape
                assert C12.shape == C12.grad.shape


def test_srgw_helper_backend(nx):
    n_samples = 20  # nb samples

    mu = np.array([0, 0])
    cov = np.array([[1, 0], [0, 1]])

    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
    xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)

    p = ot.unif(n_samples)
    q = ot.unif(n_samples)
    C1 = ot.dist(xs, xs)
    C2 = ot.dist(xt, xt)

    C1 /= C1.max()
    C2 /= C2.max()

    for loss_fun in ['square_loss', 'kl_loss']:
        C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
        Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun, armijo=False, symmetric=True, G0=None, log=True)

        # calls with nx=None
        constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun)
        ones_pb = nx.ones(pb.shape[0], type_as=pb)

        def f(G):
            qG = nx.sum(G, 0)
            marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
            return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None)

        def df(G):
            qG = nx.sum(G, 0)
            marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
            return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None)

        def line_search(cost, G, deltaG, Mi, cost_G):
            return ot.gromov.solve_semirelaxed_gromov_linesearch(
                G, deltaG, cost_G, hC1b, hC2b, ones_pb, 0., 1., fC2t=fC2tb, nx=None)
        # feed the precomputed local optimum Gb to semirelaxed_cg
        res, log = ot.optim.semirelaxed_cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
        # check constraints
        np.testing.assert_allclose(res, Gb, atol=1e-06)


@pytest.mark.parametrize('loss_fun', [
    'square_loss', 'kl_loss',
    pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)),
])
def test_gw_semirelaxed_helper_validation(loss_fun):
    n_samples = 20  # nb samples
    mu = np.array([0, 0])
    cov = np.array([[1, 0], [0, 1]])
    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
    xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
    p = ot.unif(n_samples)
    C1 = ot.dist(xs, xs)
    C2 = ot.dist(xt, xt)
    ot.gromov.init_matrix_semirelaxed(C1, C2, p, loss_fun=loss_fun)


def test_semirelaxed_fgw(nx):
    rng = np.random.RandomState(0)
    list_n = [16, 8]
    nt = 2
    ns = 24
    # create directed sbm with C2 as connectivity matrix
    C1 = np.zeros((ns, ns))
    C2 = np.array([[0.7, 0.05],
                   [0.05, 0.9]])
    for i in range(nt):
        for j in range(nt):
            ni, nj = list_n[i], list_n[j]
            xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j])
            C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij
    F1 = np.zeros((ns, 1))
    F1[:16] = rng.normal(loc=0., scale=0.01, size=(16, 1))
    F1[16:] = rng.normal(loc=1., scale=0.01, size=(8, 1))
    F2 = np.zeros((2, 1))
    F2[1, :] = 1.
    M = (F1 ** 2).dot(np.ones((1, nt))) + np.ones((ns, 1)).dot((F2 ** 2).T) - 2 * F1.dot(F2.T)

    p = ot.unif(ns)
    q0 = ot.unif(C2.shape[0])
    G0 = p[:, None] * q0[None, :]

    # asymmetric
    Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0)
    G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
    Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0b)

    # check constraints
    np.testing.assert_allclose(G, Gb, atol=1e-06)
    np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04)  # cf convergence gromov
    np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02)  # cf convergence gromov

    srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0)
    srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)

    G = log2['T']
    Gb = nx.to_numpy(logb2['T'])
    # check constraints
    np.testing.assert_allclose(G, Gb, atol=1e-06)
    np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
    np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04)  # cf convergence gromov

    np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07)
    np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)

    # symmetric
    for loss_fun in ['square_loss', 'kl_loss']:
        C1 = 0.5 * (C1 + C1.T)
        Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0)

        G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=None)
        Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=False, G0=G0b)

        # check constraints
        np.testing.assert_allclose(G, Gb, atol=1e-06)
        np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04)  # cf convergence gromov
        np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02)  # cf convergence gromov

        srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=True, G0=G0)
        srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=None)

        srgw_ = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=False, G0=G0)

        G = log2['T']
        Gb = nx.to_numpy(logb2['T'])
        # check constraints
        np.testing.assert_allclose(G, Gb, atol=1e-06)
        np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
        np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04)  # cf convergence gromov

        np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
        np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
        np.testing.assert_allclose(srgw, srgw_, atol=1e-07)


def test_semirelaxed_fgw2_gradients():
    n_samples = 20  # nb samples

    mu_s = np.array([0, 0])
    cov_s = np.array([[1, 0], [0, 1]])

    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)

    xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)

    p = ot.unif(n_samples)

    C1 = ot.dist(xs, xs)
    C2 = ot.dist(xt, xt)
    M = ot.dist(xs, xt)

    C1 /= C1.max()
    C2 /= C2.max()

    if torch:

        devices = [torch.device("cpu")]
        if torch.cuda.is_available():
            devices.append(torch.device("cuda"))
        for device in devices:
            # semirelaxed solvers do not support gradients over masses yet.
            for loss_fun in ['square_loss', 'kl_loss']:
                p1 = torch.tensor(p, requires_grad=False, device=device)
                C11 = torch.tensor(C1, requires_grad=True, device=device)
                C12 = torch.tensor(C2, requires_grad=True, device=device)
                M1 = torch.tensor(M, requires_grad=True, device=device)

                val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, loss_fun=loss_fun)

                val.backward()

                assert val.device == p1.device
                assert p1.grad is None
                assert C11.shape == C11.grad.shape
                assert C12.shape == C12.grad.shape
                assert M1.shape == M1.grad.shape

                # full gradients with alpha
                p1 = torch.tensor(p, requires_grad=False, device=device)
                C11 = torch.tensor(C1, requires_grad=True, device=device)
                C12 = torch.tensor(C2, requires_grad=True, device=device)
                M1 = torch.tensor(M, requires_grad=True, device=device)
                alpha = torch.tensor(0.5, requires_grad=True, device=device)

                val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, loss_fun=loss_fun, alpha=alpha)

                val.backward()

                assert val.device == p1.device
                assert p1.grad is None
                assert C11.shape == C11.grad.shape
                assert C12.shape == C12.grad.shape
                assert alpha.shape == alpha.grad.shape


def test_srfgw_helper_backend(nx):
    n_samples = 20  # nb samples

    mu = np.array([0, 0])
    cov = np.array([[1, 0], [0, 1]])

    rng = np.random.RandomState(42)
    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
    ys = rng.randn(xs.shape[0], 2)
    xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
    yt = rng.randn(xt.shape[0], 2)

    p = ot.unif(n_samples)
    q = ot.unif(n_samples)
    G0 = p[:, None] * q[None, :]

    C1 = ot.dist(xs, xs)
    C2 = ot.dist(xt, xt)

    C1 /= C1.max()
    C2 /= C2.max()

    M = ot.dist(ys, yt)
    M /= M.max()

    Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
    alpha = 0.5
    Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, 'square_loss', alpha=0.5, armijo=False, symmetric=True, G0=G0b, log=True)

    # calls with nx=None
    constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss')
    ones_pb = nx.ones(pb.shape[0], type_as=pb)

    def f(G):
        qG = nx.sum(G, 0)
        marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
        return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None)

    def df(G):
        qG = nx.sum(G, 0)
        marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
        return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None)

    def line_search(cost, G, deltaG, Mi, cost_G):
        return ot.gromov.solve_semirelaxed_gromov_linesearch(
            G, deltaG, cost_G, C1b, C2b, ones_pb, M=(1 - alpha) * Mb, reg=alpha, nx=None)
    # feed the precomputed local optimum Gb to semirelaxed_cg
    res, log = ot.optim.semirelaxed_cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
    # check constraints
    np.testing.assert_allclose(res, Gb, atol=1e-06)


def test_entropic_semirelaxed_gromov(nx):
    # unbalanced proportions
    list_n = [30, 15]
    nt = 2
    ns = np.sum(list_n)
    # create directed sbm with C2 as connectivity matrix
    C1 = np.zeros((ns, ns), dtype=np.float64)
    C2 = np.array([[0.8, 0.05],
                   [0.05, 1.]], dtype=np.float64)
    rng = np.random.RandomState(0)
    for i in range(nt):
        for j in range(nt):
            ni, nj = list_n[i], list_n[j]
            xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j])
            C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij
    p = ot.unif(ns, type_as=C1)
    q0 = ot.unif(C2.shape[0], type_as=C1)
    G0 = p[:, None] * q0[None, :]
    # asymmetric
    C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0)
    epsilon = 0.1
    for loss_fun in ['square_loss', 'kl_loss']:
        G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun=loss_fun, epsilon=epsilon, symmetric=None, log=True, G0=G0)
        Gb, logb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun=loss_fun, epsilon=epsilon, symmetric=False, log=True, G0=None)

        # check constraints
        np.testing.assert_allclose(G, Gb, atol=1e-06)
        np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04)
        np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01)
        np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01)

        srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, None, loss_fun=loss_fun, epsilon=epsilon, symmetric=False, log=True, G0=G0)
        srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun=loss_fun, epsilon=epsilon, symmetric=None, log=True, G0=None)

        G = log2['T']
        Gb = nx.to_numpy(logb2['T'])
        # check constraints
        np.testing.assert_allclose(G, Gb, atol=1e-06)
        np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
        np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04)  # cf convergence gromov

        np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07)
        np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07)

    # symmetric
    C1 = 0.5 * (C1 + C1.T)
    C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0)

    G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None)
    Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=False, G0=G0b)

    # check constraints
    np.testing.assert_allclose(G, Gb, atol=1e-06)
    np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04)  # cf convergence gromov
    np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02)  # cf convergence gromov

    srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=True, G0=G0)
    srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None)

    srgw_ = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=False, G0=G0)

    G = log2['T']
    # check constraints
    np.testing.assert_allclose(G, Gb, atol=1e-06)
    np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04)  # cf convergence gromov
    np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01)
    np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01)

    np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07)
    np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07)
    np.testing.assert_allclose(srgw, srgw_, atol=1e-07)


@pytest.skip_backend("jax", reason="test very slow with jax backend")
@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_semirelaxed_gromov_dtype_device(nx):
    # setup
    n_samples = 5  # nb samples

    mu_s = np.array([0, 0])
    cov_s = np.array([[1, 0], [0, 1]])

    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)

    xt = xs[::-1].copy()

    p = ot.unif(n_samples)

    C1 = ot.dist(xs, xs)
    C2 = ot.dist(xt, xt)

    C1 /= C1.max()
    C2 /= C2.max()

    for tp in nx.__type_list__:

        print(nx.dtype_device(tp))
        for loss_fun in ['square_loss', 'kl_loss']:
            C1b, C2b, pb = nx.from_numpy(C1, C2, p, type_as=tp)

            Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(
                C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True
            )
            gw_valb = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(
                C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True
            )

            nx.assert_same_dtype_device(C1b, Gb)
            nx.assert_same_dtype_device(C1b, gw_valb)


def test_entropic_semirelaxed_fgw(nx):
    rng = np.random.RandomState(0)
    list_n = [16, 8]
    nt = 2
    ns = 24
    # create directed sbm with C2 as connectivity matrix
    C1 = np.zeros((ns, ns))
    C2 = np.array([[0.7, 0.05],
                   [0.05, 0.9]])
    for i in range(nt):
        for j in range(nt):
            ni, nj = list_n[i], list_n[j]
            xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j])
            C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij
    F1 = np.zeros((ns, 1))
    F1[:16] = rng.normal(loc=0., scale=0.01, size=(16, 1))
    F1[16:] = rng.normal(loc=1., scale=0.01, size=(8, 1))
    F2 = np.zeros((2, 1))
    F2[1, :] = 1.
    M = (F1 ** 2).dot(np.ones((1, nt))) + np.ones((ns, 1)).dot((F2 ** 2).T) - 2 * F1.dot(F2.T)

    p = ot.unif(ns)
    q0 = ot.unif(C2.shape[0])
    G0 = p[:, None] * q0[None, :]

    # asymmetric
    Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0)

    G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None)
    Gb, logb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0b)

    # check constraints
    np.testing.assert_allclose(G, Gb, atol=1e-06)
    np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04)  # cf convergence gromov
    np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02)  # cf convergence gromov

    srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0)
    srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None)

    G = log2['T']
    Gb = nx.to_numpy(logb2['T'])
    # check constraints
    np.testing.assert_allclose(G, Gb, atol=1e-06)
    np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
    np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04)  # cf convergence gromov

    np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07)
    np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)

    # symmetric
    C1 = 0.5 * (C1 + C1.T)
    Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0)

    for loss_fun in ['square_loss', 'kl_loss']:
        G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None)
        Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0b)

        # check constraints
        np.testing.assert_allclose(G, Gb, atol=1e-06)
        np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04)  # cf convergence gromov
        np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02)  # cf convergence gromov

        srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=True, G0=G0)
        srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None)

        srgw_ = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0)

        G = log2['T']
        Gb = nx.to_numpy(logb2['T'])
        # check constraints
        np.testing.assert_allclose(G, Gb, atol=1e-06)
        np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04)  # cf convergence gromov
        np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04)  # cf convergence gromov

        np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
        np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
        np.testing.assert_allclose(srgw, srgw_, atol=1e-07)


@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_semirelaxed_fgw_dtype_device(nx):
    # setup
    n_samples = 5  # nb samples

    mu_s = np.array([0, 0])
    cov_s = np.array([[1, 0], [0, 1]])

    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)

    xt = xs[::-1].copy()

    rng = np.random.RandomState(42)
    ys = rng.randn(xs.shape[0], 2)
    yt = ys[::-1].copy()

    p = ot.unif(n_samples)

    C1 = ot.dist(xs, xs)
    C2 = ot.dist(xt, xt)

    C1 /= C1.max()
    C2 /= C2.max()

    M = ot.dist(ys, yt)
    for tp in nx.__type_list__:
        print(nx.dtype_device(tp))

        Mb, C1b, C2b, pb = nx.from_numpy(M, C1, C2, p, type_as=tp)

        for loss_fun in ['square_loss', 'kl_loss']:
            Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(
                Mb, C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True
            )
            fgw_valb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(
                Mb, C1b, C2b, pb, loss_fun, epsilon=0.1, verbose=True
            )

            nx.assert_same_dtype_device(C1b, Gb)
            nx.assert_same_dtype_device(C1b, fgw_valb)
