Source code for geometry.manifolds.product_manifold
from contracts import contract
import numpy as np
from .differentiable_manifold import DifferentiableManifold
__all__ = ['ProductManifold']
[docs]class ProductManifold(DifferentiableManifold):
@contract(components='seq[>=2,N]($DifferentiableManifold)',
weights='None|array[N](>0)')
def __init__(self, components, weights=None):
dim = sum([m.dimension for m in components])
DifferentiableManifold.__init__(self, dimension=dim)
self.components = components
if weights is None:
weights = np.ones(len(components))
self.weights = weights
[docs] @contract(a='seq')
def belongs(self, a):
if not len(a) == len(self.components): # XXX: what should I throw?
raise ValueError('I expect a sequence of length %d, not %d.' %
(len(a), len(self.components)))
for x, m in zip(a, self.components):
m.belongs(x)
[docs] def distance(self, a, b):
''' Computes the geodesic distance between two points. '''
distances = [m.distance(x, y) for x, y, m in zip(a, b, self.components)]
distances = np.array(distances)
return (distances * self.weights).sum()
[docs] def logmap(self, base, p):
''' Computes the logarithmic map from base point *a* to target *b*. '''
raise ValueError('Not implemented') # FIXME: finish this
[docs] def expmap(self, bv):
raise ValueError('Not implemented') # FIXME: finish this
[docs] def project_ts(self, bv):
raise ValueError('Not implemented') # FIXME: finish this
def __repr__(self):
return 'P(%s)' % "x".join([str(x) for x in self.components])