(* :Title: MaxEntSimplexDistribution *)

(* :Author: Mark Fisher *)

(* :Version: 0.1 November 2005 *)

(* :Mathematica Version: 5.0
	I use "f @@@ expr" syntax for "Apply[f, expr, {1}]".
	Maybe other things as well. *)

(* :Summary: Maximum entropy distribution on a simplex *)

(* :References:
	Mark Fisher (2005) Maximum Entropy on a Simplex: An Expository Note.
		Unpublished
	E. T. Jaynes (2003) Probability Theory: The Logic of Science.
		Cambridge University Press.
*)

(* :Notes:
	There are a few tricky issues regarding LambdaOK and LambdaLimit dealing
	with precision. As a consequence, PartitionFunction[{1, 1.}] returns an
	exact (infinite-precision) result while PartitionFunction[{1., 1}] returns
	a machine-precision result.

	At some point I might try to compile
	Mean[MaxEntSimplexDistribution[lambda]] for cases that fail LambdaOK.
	*)

BeginPackage["MaxEntSimplexDistribution`", {
	"Utilities`FilterOptions`",
	"Statistics`Common`DistributionsCommon`",  (* for PDF and RandomArray *)
	"Statistics`MultiDescriptiveStatistics`"}] (* for Covariance *)

(* Usage messages for the public symbols of the package. *)

MaxEntSimplexDistribution::usage = "MaxEntSimplexDistribution[\[Lambda]]
represents the maximum entropy distribution on a simplex given the
parameter vector \[Lambda]. PDF[MaxEntSimplexDistribution[\[Lambda]], x]
returns the PDF, where x is a vector of the same length as \[Lambda].
PDF[MaxEntSimplexDistribution[\[Lambda]], x, pos] returns the marginal PDF
for x[[pos]], where pos is a list of positions and \[Lambda] and x are from
the joint distribution. PDF[MaxEntSimplexDistribution[\[Lambda]], x, sum]
returns the conditional PDF for x (and its associated \[Lambda]), given the
sum of the conditioning variables. Mean, Covariance, Variance,
StandardDeviation, and Entropy are defined for
MaxEntSimplexDistribution[\[Lambda]], as are Random and RandomArray. If
Length[\[Lambda]] > 1, then Random calls MaxEntSimplexGibbs[\[Lambda], 2 *
Length[\[Lambda]]] and returns the last draw; otherwise, Random computes
iid draws directly."

PartitionFunction::usage = "PartitionFunction[\[Lambda]] returns the
partition function given the vector \[Lambda]. PartitionFunction[b][\[Lambda]]
computes the partition function with an upper bound of b on the sum of the
components (PartitionFunction[\[Lambda]] uses b = 1). PartitionFunction
is the normalization factor for PDF[MaxEntSimplexDistribution[\[Lambda]],
x]."

MaxEntSimplexGibbs::usage = "MaxEntSimplexGibbs[\[Lambda], n] returns n
draws from the Gibbs sampler for the maximum entropy distribution on a
simplex given the vector \[Lambda]. The sampler is initialized with
Mean[MaxEntSimplexDistribution[\[Lambda]]]. If Length[\[Lambda]] == 1, then
the draws are iid from the univariate distribution."

MaxEntSimplexInvertMean::usage = "MaxEntSimplexInvertMean[mean] returns the
\[Lambda] vector associated with the given mean. Options can be passed to
FindRoot, which MaxEntSimplexInvertMean calls. In addition,
MaxEntSimplexInvertMean takes the option StartingValues which can be used
to pass a list of starting values to FindRoot. The default setting is
StartingValues -> Automatic."

Entropy::usage = "Entropy[MaxEntSimplexDistribution[\[Lambda]]] returns the
entropy of the given distribution, where \[Lambda] is a vector."

LambdaOK::usage = "LambdaOK[\[Lambda]] returns True if no component of
\[Lambda] is zero and if no two components are equal. LambdaOK is used to
trap arguments to a number of functions related to maximum entropy."

LambdaLimit::usage = "LambdaLimit[\[Lambda], fun] computes the limit for
fun[\[Lambda]], where fun is typically PartitionFunction or its gradient or
Hessian."

MeanOK::usage = "MeanOK[mean] returns True if no component of mean equals
(1 - Tr[mean]) and if no two components are equal. MeanOK is used to trap
arguments to MaxEntSimplexInvertMean."

MeanLimit::usage = "MeanLimit[mean] computes the \[Lambda] vector in the
limiting case where mean fails MeanOK. MeanLimit is called by
MaxEntSimplexInvertMean."

StartingValues::usage = "StartingValues is an option for
MaxEntSimplexInvertMean. The default value is StartingValues -> Automatic,
which produces Range[n], where n is the number of parameters to be
estimated."

Begin["`Private`"]

(* PartitionFunction, LambdaOK, and LambdaLimit *)

(* PartitionFunction[b][lambda]: normalizing constant of Exp[-lambda . x]
	for the region where the sum of the components of x is bounded by b.
	Rules, tried in this order:
	  - lambda passing LambdaOK: closed form;
	  - lambda made only of exact zeros: b^n/n!;
	  - any other list: LambdaLimit applied to the closed form;
	  - PartitionFunction[lambda] is shorthand for b = 1;
	  - any other argument yields $Failed. *)
PartitionFunction[b_][lambda_List?LambdaOK] :=
	Module[{n = Length[lambda], denom},
	(* lambda[[k]] times the product of (lambda[[j]] - lambda[[k]]) over j != k *)
	denom[k_] :=
		lambda[[k]] * Product[If[j == k, 1, (lambda[[j]] - lambda[[k]])], {j, n}];
	1/Product[lambda[[k]], {k, n}] -
		Sum[Exp[-b lambda[[k]]]/denom[k], {k, n}]
	]

PartitionFunction[b_][lambda:{0 ..}] :=
	b^Length[lambda]/Factorial[Length[lambda]]

PartitionFunction[b_][lambda_List] :=
	LambdaLimit[lambda, PartitionFunction[b]]

PartitionFunction[lambda_List] := PartitionFunction[1] @ lambda

PartitionFunction[_][args__] := $Failed

(* LambdaOK[lambda] is True when the components of lambda are pairwise
	distinct (compared with Equal) and none of them equals 0, which is
	when the closed-form PartitionFunction applies.  Symbolic components,
	for which Equal stays unevaluated, count as distinct and nonzero. *)
LambdaOK[lambda_] :=
	With[{u = Union[lambda, SameTest -> Equal]},
	Length[u] == Length[lambda] &&
	(* compare each component on its own: in Thread[u == 0] the comparison
	   u == 0 is evaluated before Thread sees it, and for a list of numbers
	   it can collapse to a single False, hiding a zero component *)
	FreeQ[(# == 0 &) /@ u, True]
	]

(* LambdaLimit[lambda, fun] evaluates fun at a lambda that fails LambdaOK
	(repeated or zero components), where closed-form expressions are
	singular.  fun is applied to a vector of distinct symbols.  Within each
	group of equal components of lambda, every symbol after the first is
	sent by successive Limits to the group's first symbol.  If some group
	equals 0, that group's symbol is then sent to 0.  The symbols that
	remain are finally replaced by the corresponding values of lambda. *)
LambdaLimit[lambda_List, fun_] :=
	Module[{lam, x, ulam, upos, rules, zpos},
	lam = Array[x, Length[lambda]];
	ulam = Union[lambda, SameTest -> Equal];
	(* upos[[k]]: positions (top level only) of the components equal to ulam[[k]] *)
	upos = Function[u,
		Flatten[Position[lambda, _?(# == u &), {1}, Heads -> False]]] /@ ulam;
	(* x[j] -> x[first] for the non-first positions of each group; Thread over
	   an empty Rest gives no rules, so singleton groups contribute nothing *)
	rules = Flatten[Thread[(x /@ Rest[#]) -> x[First[#]]] & /@ upos];
	zpos = Position[ulam, _?(# == 0 &), {1}, Heads -> False];
	If[zpos =!= {}, AppendTo[rules, x[upos[[zpos[[1, 1]], 1]]] -> 0]];
	Fold[Limit[#1, #2] &, fun[lam], rules] /. Thread[lam -> lambda]
	]

(* PDF: joint, conditional, and marginal *)

(* joint density: Exp[-lambda . x] divided by PartitionFunction[lambda];
	x must have the same length as lambda *)
MaxEntSimplexDistribution /:
PDF[MaxEntSimplexDistribution[lambda_List], x_List] /;
		Length[lambda] == Length[x] :=
	With[{z = PartitionFunction[lambda]},
	Exp[-lambda . x]/z
	]

(* marginal density of x[[pos]]: Exp[-lambda[[pos]] . x[[pos]]] times the
	partition function of the remaining components with bound
	1 - Tr[x[[pos]]], over the full partition function.  Under LambdaOK the
	components of lambda are distinct, so Complement (which removes by value
	and sorts the result) drops exactly the components at pos. *)
MaxEntSimplexDistribution /:
PDF[MaxEntSimplexDistribution[lambda_List?LambdaOK], x_List, pos_List] /;
		Length[lambda] == Length[x] :=
	Module[{rest, bound},
	rest = Complement[lambda, lambda[[pos]]];
	bound = 1 - Tr[x[[pos]]];
	Exp[-lambda[[pos]] . x[[pos]]] * PartitionFunction[bound][rest] /
		PartitionFunction[lambda]
	]

(* same ratio, taken through LambdaLimit when lambda fails LambdaOK *)
MaxEntSimplexDistribution /:
PDF[MaxEntSimplexDistribution[lambda_List], x_List, pos_List] /;
		Length[lambda] == Length[x] :=
	With[{bound = 1 - Tr[x[[pos]]]},
	Exp[-lambda[[pos]] . x[[pos]]] *
	LambdaLimit[lambda,
		Function[v,
			PartitionFunction[bound][Complement[v, v[[pos]]]] /
			PartitionFunction[v]]]
	]

(* conditional density of x given the sum of the conditioning variables:
	normalized by the partition function whose bound is 1 - sum *)
MaxEntSimplexDistribution /:
PDF[MaxEntSimplexDistribution[lambda_List], x_List, sum_] /;
		Length[lambda] == Length[x] :=
	With[{z = PartitionFunction[1 - sum][lambda]},
	Exp[-lambda . x]/z
	]

(* any other argument form yields $Failed *)
MaxEntSimplexDistribution /:
PDF[MaxEntSimplexDistribution[__], args__] := $Failed


(* Mean, Covariance, Variance, StandardDeviation, Entropy *)

(* Mean of MaxEntSimplexDistribution[lambda]: minus the gradient of
	Log[PartitionFunction] with respect to lambda.  Rules, in order:
	  - the empty vector;
	  - lambda passing LambdaOK (symbolic gradient, then substitute lambda);
	  - any other list, through LambdaLimit;
	  - numeric machine-precision lambda passing LambdaOK, via compiledmean.
	NOTE(review): the compiled rule shares its pattern with the symbolic
	LambdaOK rule and differs only by its condition, so whether it is tried
	first depends on how the upvalues are ordered -- confirm with
	UpValues[MaxEntSimplexDistribution]. *)
MaxEntSimplexDistribution /:
Mean[MaxEntSimplexDistribution[{}]] := {}

MaxEntSimplexDistribution /:
Mean[MaxEntSimplexDistribution[lambda_List?LambdaOK]] :=
	Module[{sym, vars},
	vars = Array[sym, Length[lambda]];
	ReplaceAll[-D[Log[PartitionFunction[vars]], {vars}], Thread[vars -> lambda]]
	]

MaxEntSimplexDistribution /:
Mean[MaxEntSimplexDistribution[lambda_List]] :=
	LambdaLimit[lambda, Function[v, -D[Log[PartitionFunction[v]], {v}]]]

(* compiled path for numeric, machine-precision lambda *)
MaxEntSimplexDistribution /:
Mean[MaxEntSimplexDistribution[lambda_List?LambdaOK]] /;
		VectorQ[lambda, NumericQ] &&
		Precision[lambda] == MachinePrecision :=
	Apply[compiledmean[Length[lambda]], lambda]

(* Covariance of MaxEntSimplexDistribution[lambda]: the Hessian of
	Log[PartitionFunction] with respect to lambda; closed form when lambda
	passes LambdaOK, LambdaLimit otherwise. *)
MaxEntSimplexDistribution /:
Covariance[MaxEntSimplexDistribution[lambda_List?LambdaOK]] :=
	Module[{sym, vars},
	vars = Array[sym, Length[lambda]];
	ReplaceAll[D[Log[PartitionFunction[vars]], {vars}, {vars}],
		Thread[vars -> lambda]]
	]

MaxEntSimplexDistribution /:
Covariance[MaxEntSimplexDistribution[lambda_List]] :=
	LambdaLimit[lambda, Function[v, D[Log[PartitionFunction[v]], {v}, {v}]]]

(* Variance: the diagonal of the covariance matrix (Transpose with a
	repeated index extracts it) *)
MaxEntSimplexDistribution /:
Variance[MaxEntSimplexDistribution[lambda_List]] :=
	With[{cov = Covariance[MaxEntSimplexDistribution[lambda]]},
	Transpose[cov, {1, 1}]
	]

(* StandardDeviation: square roots of the variances *)
MaxEntSimplexDistribution /:
StandardDeviation[MaxEntSimplexDistribution[lambda_List]] :=
	Variance[MaxEntSimplexDistribution[lambda]]^(1/2)

(* Entropy: lambda . Mean + Log[PartitionFunction[lambda]] *)
MaxEntSimplexDistribution /:
Entropy[MaxEntSimplexDistribution[lambda_List]] :=
	With[{dist = MaxEntSimplexDistribution[lambda]},
	lambda . Mean[dist] + Log[PartitionFunction[lambda]]
	]

(* any other argument form yields $Failed for each of these statistics *)
Scan[
	Function[stat,
		MaxEntSimplexDistribution /: stat[MaxEntSimplexDistribution[__]] := $Failed],
	{Mean, Covariance, Variance, StandardDeviation, Entropy}]

(* MaxEntSimplexInvertMean *)

(* Numerically solve for lambda given the mean vector.
	Options are filtered and passed to FindRoot; StartingValues supplies the
	FindRoot starting point (Automatic means Range[Length[mean]]).  A list of
	starting values of the wrong length returns $Failed.  A mean that fails
	MeanOK is handed to MeanLimit. *)
Options[MaxEntSimplexInvertMean] = {StartingValues -> Automatic}

MaxEntSimplexInvertMean[{}, ___] := {}

MaxEntSimplexInvertMean[mean_?(VectorQ[#, NumericQ]&), opts___?OptionQ] /;
		MeanOK[mean] :=
	Module[{len = Length[mean], lam, x, fropts, start},
	fropts = FilterOptions[FindRoot, opts];
	start = StartingValues /. {opts} /. Options[MaxEntSimplexInvertMean];
	(* SameQ: for a list of starting values, start == Automatic would stay
	   unevaluated instead of giving False *)
	If[start === Automatic, start = Range[len]];
	If[Length[start] != len, Return[$Failed]];
	lam = Array[x, len];
	lam /. FindRoot[
		Mean[MaxEntSimplexDistribution[lam]] - mean,
		Evaluate[Sequence @@ Transpose[{lam, start}], fropts]
		]
	]

MaxEntSimplexInvertMean[mean_?(VectorQ[#, NumericQ]&), opts___?OptionQ] :=
	MeanLimit[mean, opts]

(* MeanLimit[mean, opts] inverts a mean that fails MeanOK.
	The lambda components are tied together wherever the corresponding mean
	components are equal.  The group whose mean equals 1 - Tr[mean] is set
	to 0.  FindRoot then solves only for the remaining distinct symbols,
	using the limit of the mean expression under those ties.  If every
	component is forced to 0, that all-zero vector is returned without
	calling FindRoot. *)
MeanLimit[mean_List, opts___?OptionQ] :=
	Module[{lam, x, slack, umean, upos, rules, zpos, temp,
		utemp, frpos, len, fropts, start, meanexpr, meandiff},
	lam = Array[x, Length[mean]];
	slack = 1 - Tr[mean];
	umean = Union[mean, SameTest -> Equal];
	(* upos[[k]]: positions (top level only) of the components equal to umean[[k]] *)
	upos = Function[u,
		Flatten[Position[mean, _?(# == u &), {1}, Heads -> False]]] /@ umean;
	(* tie non-first symbols of each group to the group's first symbol;
	   singleton groups produce no rules *)
	rules = Flatten[Thread[(x /@ Rest[#]) -> x[First[#]]] & /@ upos];
	zpos = Position[umean, _?(# == slack &), {1}, Heads -> False];
	If[zpos =!= {}, AppendTo[rules, x[upos[[zpos[[1, 1]], 1]]] -> 0]];
	temp = lam //. rules;
	utemp = Union[DeleteCases[temp, 0]];
	(* indices of the free representative symbols; RuleDelayed so that any
	   value of i cannot leak into the right-hand side *)
	frpos = utemp /. x[i_] :> i;
	If[frpos === {},
		(* then: every component is forced to 0 *)
		temp,
		(* else: solve for the free representatives *)
		len = Length[utemp];
		fropts = FilterOptions[FindRoot, opts];
		start = StartingValues /. {opts} /. Options[MaxEntSimplexInvertMean];
		If[start === Automatic, start = Range[len]];
		If[Length[start] != len, Return[$Failed]];
		meanexpr = Fold[Limit[#1, #2] &,
			-D[Log[PartitionFunction[lam]], {lam}], rules] /. Thread[lam -> temp];
		meandiff = meanexpr - mean;
		temp /. FindRoot[Evaluate @ meandiff[[frpos]],
			Evaluate @ Transpose[{utemp, start}], Evaluate @ fropts]
		]
	]

(* any other argument form yields $Failed *) MaxEntSimplexInvertMean[args__] := $Failed

(* MeanOK[mean] is True when the components of mean are pairwise distinct
	(compared with Equal) and none equals the remaining mass 1 - Tr[mean]. *)
MeanOK[mean_] :=
	With[{u = Union[mean, SameTest -> Equal], slack = 1 - Tr[mean]},
	Length[u] == Length[mean] &&
	(* compare each component on its own: in Thread[u == slack] the
	   comparison is evaluated before Thread sees it, and for a list of
	   numbers it can collapse to a single False *)
	FreeQ[(# == slack &) /@ u, True]
	]

(* Random, RandomArray, Gibbs sampler *)

(* univariate case (one-component lambda): draws come directly from the
	compiled samplers maxent (range [0, 1]) and maxentb (range [0, bound]) *)
MaxEntSimplexGibbs[{lam_?NumericQ}, n_Integer?Positive] :=
	Array[maxent[lam] &, n]

MaxEntSimplexDistribution /:
Random[MaxEntSimplexDistribution[{lam_?NumericQ}]] :=
	maxent @ lam

(* NOTE(review): the second argument b is passed on as 1 - b, the upper end
	of the draw; presumably b is the sum of the other components -- confirm *)
MaxEntSimplexDistribution /:
Random[MaxEntSimplexDistribution[{lam_?NumericQ}], b_] :=
	maxentb[lam, 1 - b]

MaxEntSimplexDistribution /:
RandomArray[MaxEntSimplexDistribution[{lam_?NumericQ}], n_Integer?Positive] :=
	Array[maxent[lam] &, n]

(* multivariate Gibbs sampler, started at
	Mean[MaxEntSimplexDistribution[lambda]] (via the compiled mean).
	Each of the n sweeps redraws every component in turn with maxentb,
	bounded by 1 minus the sum of the other components, and records the
	state after the sweep. *)
MaxEntSimplexGibbs[lambda_?(VectorQ[#, NumericQ]&), n_Integer?Positive] /;
		LambdaOK[lambda] :=
	Module[{state, offdiag, dim = Length[lambda]},
	state = compiledmean[dim] @@ lambda;
	(* row k has 0 at position k and 1 elsewhere, so state . offdiag[[k]]
	   is the sum of the other components *)
	offdiag = 1 - IdentityMatrix[dim];
	Table[
		state[[k]] = maxentb[lambda[[k]], 1 - state . offdiag[[k]]],
		{n}, {k, dim}]
	] /; Length[lambda] > 1

MaxEntSimplexGibbs[args__] := $Failed

(* Random makes one call MaxEntSimplexGibbs[lambda, 2 * Length[lambda]]
	and keeps only the last of its draws.  RandomArray repeats that
	procedure n times, with a fresh call each time. *)
MaxEntSimplexDistribution /:
Random[MaxEntSimplexDistribution[lambda_?(VectorQ[#, NumericQ]&)]] :=
	Last[MaxEntSimplexGibbs[lambda, 2 Length[lambda]]]

MaxEntSimplexDistribution /:
RandomArray[MaxEntSimplexDistribution[lambda_?(VectorQ[#, NumericQ]&)],
		n_Integer?Positive] :=
	With[{draws = 2 Length[lambda]},
	Table[Last[MaxEntSimplexGibbs[lambda, draws]], {n}]
	]

(* maxent[lam]: one draw on [0, 1] obtained by transforming a uniform
	variate r with -Log[1 + (E^(-lam) - 1) r]/lam; for Abs[lam] < 0.0001 a
	first-order expansion in lam is used instead of the closed form. *)
maxent = Compile[{{lam, _Real}},
	Module[{r = Random[]},
	If[Abs[lam] >= 0.0001,
		-(Log[1 + (E^(-lam) - 1) * r]/lam),
		r - 0.5 * r * (1 - r) * lam
		]
	]]

(* maxentb[lam, b]: the same construction on [0, b] *)
maxentb = Compile[{{lam, _Real}, {b, _Real}},
	Module[{r = Random[]},
	If[Abs[lam] >= 0.0001,
		-(Log[1 + (E^(-b * lam) - 1) * r]/lam),
		b * r - 0.5 * b^2 * r * (1 - r) * lam
		]
	]]

(* compiledmean[n]: a compiled function of n arguments that evaluates the
	symbolic Mean[MaxEntSimplexDistribution[...]] for length-n lambda.  The
	first call for a given n stores the result as compiledmean[n], so the
	symbolic derivation and compilation happen only once per n. *)
compiledmean[n_Integer?Positive] :=
	(compiledmean[n] =
		Block[{x},
		With[{vars = Array[x, n]},
			Compile @@ {vars, Mean[MaxEntSimplexDistribution[vars]]}
			]
		])

End[]
EndPackage[]
