# -*- coding: utf-8 -*-
"""
Functions for computing minimal sufficient statistics.
"""
from collections import defaultdict
from .lattice import dist_from_induced_sigalg, insert_join, insert_rv
from .prune_expand import pruned_samplespace
from ..helpers import flatten, parse_rvs, normalize_rvs
from ..math import sigma_algebra
from ..samplespace import CartesianProduct
# Names exported by a star-import of this module.
__all__ = [
    'info_trim',
    'insert_mss',
    'mss',
    'mss_sigalg',
]
def partial_match(first, second, places):
    """
    Test whether `second` equals `first` restricted to the positions `places`.

    Parameters
    ----------
    first : iterable
        The full (un-marginalized) outcome. It is indexed with every element
        of `places`.
    second : iterable
        The candidate marginal outcome.
    places : list
        The positions of `first` to compare against `second`, in order.

    Returns
    -------
    match : bool
        `True` when the elements of `first` found at `places` are equal, as a
        tuple, to the elements of `second`; `False` otherwise.

    """
    # Compare as tuples so that strings, lists and tuples are all treated
    # uniformly as sequences of symbols.
    projected = tuple(first[position] for position in places)
    expected = tuple(second)
    return projected == expected
def mss_sigalg(dist, rvs, about=None, rv_mode=None):
    """
    Build the sigma-algebra of the minimal sufficient statistic of `rvs`
    about `about`.

    Outcomes of `dist` are grouped together whenever the conditional
    distributions attached to their `rvs` part (as returned by
    `dist.condition_on`) are approximately equal; the sigma-algebra is then
    generated from those groups.

    Parameters
    ----------
    dist : Distribution
        The distribution which defines the base sigma-algebra.
    rvs : list
        The random variables to be compressed into a minimal sufficient
        statistic.
    about : list
        The random variables about which the minimal sufficient statistic
        must retain all information. Forwarded unchanged to
        `dist.condition_on` as its `rvs` argument.
    rv_mode : str, None
        Specifies how to interpret the elements of `rvs`. Valid options are:
        {'indices', 'names'}. If equal to 'indices', then the elements of
        `rvs` are interpreted as random variable indices. If equal to 'names',
        then the elements are interpreted as random variable names. If `None`,
        then the value of `dist._rv_mode` is consulted.

    Returns
    -------
    mss_sa : frozenset of frozensets
        The induced sigma-algebra of the minimal sufficient statistic.

    Examples
    --------
    >>> d = Xor()
    >>> mss_sigalg(d, [0], [1, 2])
    frozenset({frozenset(),
               frozenset({'000', '011'}),
               frozenset({'101', '110'}),
               frozenset({'000', '011', '101', '110'})})

    """
    # Second element of `parse_rvs`; used below as the positions of `rvs`
    # inside each joint outcome.
    positions = parse_rvs(dist, rvs, rv_mode=rv_mode)[1]

    marginal, conditionals = dist.condition_on(rvs=about, crvs=rvs,
                                               rv_mode=rv_mode)

    # Maps a representative conditional distribution to the joint outcomes
    # that share (approximately) that conditional distribution.
    atoms = defaultdict(list)
    for marginal_outcome, conditional in zip(marginal.outcomes, conditionals):
        members = [outcome for outcome in dist.outcomes
                   if partial_match(outcome, marginal_outcome, positions)]

        # Reuse the first existing group whose representative is
        # approximately equal; otherwise start a new group keyed by this
        # conditional distribution.
        bucket = next(
            (group for representative, group in atoms.items()
             if representative.is_approx_equal(conditional)),
            None,
        )
        if bucket is None:
            bucket = atoms[conditional]
        bucket.extend(members)

    return sigma_algebra(map(frozenset, atoms.values()))
def insert_mss(dist, idx, rvs, about=None, rv_mode=None):
    """
    Inserts the minimal sufficient statistic of `rvs` about `about` into `dist`
    at index `idx`.

    Parameters
    ----------
    dist : Distribution
        The distribution which defines the base sigma-algebra.
    idx : int
        The location in the distribution to insert the minimal sufficient
        statistic.
    rvs : list
        A list of random variables to be compressed into a minimal sufficient
        statistic.
    about : list
        A list of random variables about which the minimal sufficient
        statistic will retain all information. Forwarded unchanged to
        `mss_sigalg`.
    rv_mode : str, None
        Specifies how to interpret the elements of `rvs`. Valid options are:
        {'indices', 'names'}. If equal to 'indices', then the elements of
        `rvs` are interpreted as random variable indices. If equal to 'names',
        then the elements are interpreted as random variable names. If `None`,
        then the value of `dist._rv_mode` is consulted.

    Returns
    -------
    d : Distribution
        The distribution `dist` modified to contain the minimal sufficient
        statistic, with its sample space pruned.

    Examples
    --------
    >>> d = Xor()
    >>> print(insert_mss(d, -1, [0], [1, 2]))
    Class:          Distribution
    Alphabet:       ('0', '1') for all rvs
    Base:           linear
    Outcome Class:  str
    Outcome Length: 4
    RV Names:       None
    x      p(x)
    0000   0.25
    0110   0.25
    1011   0.25
    1101   0.25

    """
    # Build the statistic's sigma-algebra, add it to `dist` as a new random
    # variable at `idx`, then prune the resulting sample space.
    mss_sa = mss_sigalg(dist, rvs, about, rv_mode)
    new_dist = insert_rv(dist, idx, mss_sa)
    return pruned_samplespace(new_dist)
def mss(dist, rvs, about=None, rv_mode=None, int_outcomes=True):
    """
    Construct the distribution of the minimal sufficient statistic of `rvs`
    about `about`.

    Parameters
    ----------
    dist : Distribution
        The distribution which defines the base sigma-algebra.
    rvs : list
        A list of random variables to be compressed into a minimal sufficient
        statistic.
    about : list
        A list of random variables about which the minimal sufficient
        statistic will retain all information. Forwarded unchanged to
        `mss_sigalg`.
    rv_mode : str, None
        Specifies how to interpret the elements of `rvs`. Valid options are:
        {'indices', 'names'}. If equal to 'indices', then the elements of `rvs`
        are interpreted as random variable indices. If equal to 'names', then
        the elements are interpreted as random variable names. If `None`, then
        the value of `dist._rv_mode` is consulted.
    int_outcomes : bool
        If `True`, then the outcomes of the minimal sufficient statistic are
        relabeled as integers instead of as the atoms of the induced
        sigma-algebra.

    Returns
    -------
    d : ScalarDistribution
        The distribution of the minimal sufficient statistic.

    Examples
    --------
    >>> d = Xor()
    >>> print(mss(d, [0], [1, 2]))
    Class:    ScalarDistribution
    Alphabet: (0, 1)
    Base:     linear
    x   p(x)
    0   0.5
    1   0.5

    """
    # The statistic is defined through its induced sigma-algebra; the
    # distribution over it is derived from `dist`.
    mss_sa = mss_sigalg(dist, rvs, about, rv_mode)
    d = dist_from_induced_sigalg(dist, mss_sa, int_outcomes)
    return d
def insert_joint_mss(dist, idx, rvs=None, rv_mode=None):
    """
    Return a copy of `dist` extended with a single variable: the join of the
    minimal sufficient statistics of each random variable in `rvs` about all
    of the other variables in `rvs`.

    Parameters
    ----------
    dist : Distribution
        The distribution containing the random variables from which the joint
        minimal sufficient statistic will be computed.
    idx : int
        The location in the distribution to insert the joint minimal
        sufficient statistic. A value greater than the outcome length of
        `dist` is replaced by -1.
    rvs : list
        A list of random variables to be compressed into a joint minimal
        sufficient statistic. Normalized with `normalize_rvs` first.
    rv_mode : str, None
        Specifies how to interpret the elements of `rvs`. Valid options are:
        {'indices', 'names'}. If equal to 'indices', then the elements of
        `rvs` are interpreted as random variable indices. If equal to 'names',
        then the elements are interpreted as random variable names. If `None`,
        then the value of `dist._rv_mode` is consulted.

    Returns
    -------
    d : Distribution
        The pruned distribution containing the joint minimal sufficient
        statistic. Its sample space is rebuilt as a `CartesianProduct` when
        the sample space of `dist` is one.

    """
    rvs, _, rv_mode = normalize_rvs(dist, rvs, None, rv_mode)

    joint = dist.copy()
    n_before = joint.outcome_length()

    # Append one temporary minimal sufficient statistic per variable group,
    # each computed about all of the remaining groups.
    groups = {tuple(rv) for rv in rvs}
    for group in groups:
        others = list(flatten(groups - {group}))
        joint = insert_mss(joint, -1, rvs=list(group), about=others,
                           rv_mode=rv_mode)
    n_after = joint.outcome_length()

    # The temporary statistics occupy positions n_before .. n_after - 1.
    temporaries = range(n_before, n_after)

    if idx > n_before:
        idx = -1
    joint = insert_join(joint, idx, [[i] for i in temporaries])

    # Any idx other than -1 is assumed to place the join ahead of the
    # temporaries, moving each of them one position to the right.
    # NOTE(review): this presumes `idx` is -1 or non-negative -- confirm for
    # other negative values.
    if idx == -1:
        shift = 0
    else:
        shift = 1
    joint = joint.marginalize([i + shift for i in temporaries])

    joint = pruned_samplespace(joint)

    if isinstance(dist._sample_space, CartesianProduct):
        joint._sample_space = CartesianProduct(joint.alphabet)

    return joint
def info_trim(dist, rvs=None, rv_mode=None):
    """
    Return a new distribution in which each random variable in `rvs` is
    replaced by its minimal sufficient statistic about all of the other
    variables in `rvs`.

    Parameters
    ----------
    dist : Distribution
        The distribution containing the random variables from which the
        minimal sufficient statistics will be computed.
    rvs : list
        A list of random variables to be compressed into minimal sufficient
        statistics. Normalized with `normalize_rvs` first.
    rv_mode : str, None
        Specifies how to interpret the elements of `rvs`. Valid options are:
        {'indices', 'names'}. If equal to 'indices', then the elements of
        `rvs` are interpreted as random variable indices. If equal to 'names',
        then the elements are interpreted as random variable names. If `None`,
        then the value of `dist._rv_mode` is consulted.

    Returns
    -------
    d : Distribution
        The pruned distribution after the original `rvs` have been
        marginalized away. Its sample space is rebuilt as a
        `CartesianProduct` when the sample space of `dist` is one.

    """
    rvs, _, rv_mode = normalize_rvs(dist, rvs, None, rv_mode)
    all_groups = {tuple(rv) for rv in rvs}

    # Append the minimal sufficient statistic of every group, in the order
    # given by `rvs`, each computed about the remaining groups.
    trimmed = dist.copy()
    for rv in rvs:
        group = tuple(rv)
        others = list(flatten(all_groups - {group}))
        trimmed = insert_mss(trimmed, -1, rvs=group, about=others,
                             rv_mode=rv_mode)

    # Drop the original variables so only the appended statistics (and any
    # variables not listed in `rvs`) remain.
    originals = list(flatten(rvs))
    trimmed = pruned_samplespace(trimmed.marginalize(originals))

    if isinstance(dist._sample_space, CartesianProduct):
        trimmed._sample_space = CartesianProduct(trimmed.alphabet)

    return trimmed