Package geometry :: Package manifolds :: Module product_manifold
[hide private]
[frames] | [no frames]

Source Code for Module geometry.manifolds.product_manifold

 1  from . import DifferentiableManifold, contract, np 
class ProductManifold(DifferentiableManifold):
    ''' Cartesian product ``M_1 x ... x M_k`` of differentiable manifolds.

        A point of the product is a sequence with exactly one entry per
        component manifold, in the same order as ``self.components``.
        Membership, distance and ``repr`` are implemented; the tangent-space
        operations (``logmap``, ``expmap``, ``project_ts``) still raise.
    '''

    @contract(components='seq[>=2,N](DifferentiableManifold)',
              weights='None|array[N](>0)')
    def __init__(self, components, weights=None):
        ''' Builds the product of the given manifolds.

            :param components: the factor manifolds (at least two).
            :param weights: optional array of positive per-component weights,
                used by :py:meth:`distance`. Defaults to all ones.
        '''
        # The product's dimension is the sum of the factors' dimensions.
        dim = sum(m.dimension for m in components)
        DifferentiableManifold.__init__(self, dimension=dim)
        self.components = components
        if weights is None:
            weights = np.ones(len(components))
        self.weights = weights

    @contract(a='seq')
    def belongs(self, a):
        ''' Checks that *a* is a point of this product manifold.

            Each entry of *a* is checked by the ``belongs`` method of the
            matching component manifold.

            :raises ValueError: if *a* does not have one entry per component.
        '''
        if len(a) != len(self.components):
            # Expected length first, actual length second, matching the
            # wording of the message.
            raise ValueError('I expect a sequence of length %d, not %d.' %
                             (len(self.components), len(a)))
        for x, m in zip(a, self.components):
            m.belongs(x)

    def distance(self, a, b):
        ''' Computes the distance between two points *a* and *b*.

            The result is the weighted sum of the component distances::

                sum_i weights[i] * components[i].distance(a[i], b[i])

            NOTE(review): this is a weighted L1 combination of the component
            distances. The geodesic distance of the Riemannian product metric
            would instead be ``sqrt(sum_i weights[i] * d_i**2)``. Confirm
            which one callers expect.
        '''
        # Pair the i-th entry of *a* with the i-th entry of *b* and measure
        # them on the i-th component manifold.
        distances = np.array([m.distance(x, y)
                              for x, y, m in zip(a, b, self.components)])
        return (distances * self.weights).sum()

    def logmap(self, a, b):
        ''' Computes the logarithmic map from base point *a* to target *b*.

            Not implemented yet.

            :raises ValueError: always.
        '''
        # FIXME: finish this. ValueError (rather than NotImplementedError)
        # is kept so existing callers that catch it keep working.
        raise ValueError('Not implemented')

    def expmap(self, a, v):
        ''' Computes the exponential map at base point *a* of tangent vector *v*.

            Not implemented yet.

            :raises ValueError: always.
        '''
        # FIXME: finish this.
        raise ValueError('Not implemented')

    def project_ts(self, base, v_ambient):
        ''' Projects the ambient vector *v_ambient* onto the tangent space at *base*.

            Not implemented yet.

            :raises ValueError: always.
        '''
        # FIXME: finish this.
        raise ValueError('Not implemented')

    def __repr__(self):
        ''' Returns e.g. ``P(AxB)`` for the product of manifolds A and B. '''
        return 'P(%s)' % "x".join(str(m) for m in self.components)