/* blahut.c */

#include <assert.h>
#include "blahut.h"
 
/* print useful information about what is going on */
#define DEBUG

/* print result in calculation for algorithm correctness check*/
/* #define DEBUG_Calc */

/* print warning message */
#define DEBUG_PRINT_WARNING 

/* Maximum amount of iteration allowed in calculating a point on C(E) curve */
#define DEFAULT_MAX_IT UINT_MAX

/* Tolerance within which two double values are judged to be the same.
 * Used to check that each row of Q sums to 1; it must therefore be a small
 * positive number (a large value would make that check accept anything). */
#define DOUBLE_COMP_LIMIT 1e-8

/* 
 * Many systems already have my_log2 function. If not, remove the 
 * ifdef, endif preprocessor to enable this function.
 */
inline static double my_log2(const double value)
{
    return log10(value)/log10(2);
}

/*
 * Natural Logarithm.
 */
inline static double my_loge(const double value)
{
    return log(value);
}

/*
 * Return 1 when no element of `vec` is less than zero, 0 as soon as a
 * negative element is found.
 */
static int 
vector_isnonneg(const gsl_vector* vec)
{
    unsigned int idx = 0;

    while (idx < vec->size) {
        const double entry = gsl_vector_get(vec, idx);
        if (entry < 0) {
            return 0;
        }
        idx++;
    }
    return 1;
}

/*
 * Return 1 when no element of `mat` is less than zero, 0 as soon as a
 * negative element is found.  Elements are visited row by row.
 */
static int 
matrix_isnonneg(const gsl_matrix* mat)
{
    unsigned int row, col;

    for (row = 0; row < mat->size1; row++) {
        for (col = 0; col < mat->size2; col++) {
            const double entry = gsl_matrix_get(mat, row, col);
            if (entry < 0.0) {
                return 0;
            }
        }
    }
    return 1;
}

/*
 * Create and initialize a blahut_cap for the channel matrix Q and the
 * vector e.
 *
 *   Q  matrix with one row per channel input and one column per channel
 *      output.  All entries must be non-negative and every row must sum
 *      to 1 (within DOUBLE_COMP_LIMIT).
 *   e  vector with one non-negative entry per channel input
 *      (e->size must equal Q->size1).  It supplies the e_j used in
 *      s * e_j and in E = sum_j p_j * e_j.
 *
 * Q and e are stored by pointer, not copied, and blahut_cap_free() does not
 * release them, so they must stay valid for as long as the returned object
 * is used.
 *
 * The returned object has p set to the uniform distribution, P and c
 * zeroed, unit = BITS, s range [0, 1e4] with step 0.001, epsilon = 1e-5
 * and maxNumIt = DEFAULT_MAX_IT.
 *
 * On invalid arguments or when the struct cannot be allocated, an error is
 * printed to stderr and the process exits (status 1; status 2 for a row of
 * Q that does not sum to 1).
 */
blahut_cap * 
blahut_cap_init( const gsl_matrix* Q, 
		   const gsl_vector* e )
{
    unsigned int i = 0, k = 0;
    double sum_Q = 0;
    blahut_cap * cap = NULL;

    /* Check validity of Q and e before allocating anything */
    if (Q == NULL || e == NULL) {
        fprintf(stderr, "(E) Q or e is NULL pointer.\n");
        exit(1);
    }
    if (Q->size1 != e->size) {
        fprintf(stderr, "(E) Q's # rows is not equal to e's size.\n");
        exit(1);
    } else if (Q->size1 == 0) {
        /* numIn == 0 would make the uniform probability 1.0/numIn infinite */
        fprintf(stderr, "(E) Q's # rows should be positive.\n");
        exit(1);
    } else if (Q->size2 == 0) {
        fprintf(stderr, "(E) Q's # columns should be positive.\n");
        exit(1);
    } else if (!vector_isnonneg(e) || !matrix_isnonneg(Q)) {
        fprintf(stderr, "(E) Q or e contains negative elements.\n");
        exit(1);
    }

    /* each row of Q should sum to 1 */
    for (i = 0; i < Q->size1; i++) {
        sum_Q = 0;
        for (k = 0; k < Q->size2; k++) {
            sum_Q += gsl_matrix_get(Q, i, k);
        }
        /* written as !(... <= ...) so that a NaN row sum is rejected too */
        if (!(fabs(sum_Q - 1.0) <= DOUBLE_COMP_LIMIT)) {
            /* report the row index i (k is the column counter) */
            fprintf(stderr, "(E) Sum over row %u of Q seems not to be 1.\n", i);
            exit(2);
        }
    }

    cap = malloc(sizeof *cap);
    if (!cap) {
        fprintf(stderr, "(E) Not enough memory when allocating a blahut_cap.\n");
        exit(1);
    }
    /* zero every field, so the ce_curve pointers start out as NULL */
    memset(cap, 0, sizeof *cap);

    /* Q and e are borrowed: only the pointers are kept */
    cap->Q = (gsl_matrix*)Q;
    cap->e = (gsl_vector*)e;

    /* Process Q and e */
    cap->numIn = Q->size1;
    cap->numOut = Q->size2;

    /* Initialize p, P, c */
    cap->p = gsl_vector_alloc(cap->numIn);
    gsl_vector_set_all(cap->p, 1.0/cap->numIn); /* init. the input distribution
                                                   as uniform */
    cap->P = gsl_matrix_calloc(cap->numOut, cap->numIn); /* numOut x numIn */
    cap->c = gsl_vector_calloc(cap->numIn); /* numIn x 1 */

    /* The default values */
    cap->unit = BITS;
    cap->s_L = 0.0;
    cap->s_U = 1e4;
    cap->s_d = 0.001;

    cap->epsilon = 1e-5;
    cap->maxNumIt = DEFAULT_MAX_IT;

    return cap;
}

/*
 * Release a blahut_cap: frees P, p, c, whichever of the ce_curve buffers
 * (p, E, C) have been allocated, and finally the struct itself.
 *
 * The Q and e pointers stored in the struct are not freed here.
 * Passing NULL is a no-op.
 */
void 
blahut_cap_free( blahut_cap* cap)
{
    if (cap == NULL) {
        return;
    }

    gsl_matrix_free(cap->P);
    gsl_vector_free(cap->p);
    gsl_vector_free(cap->c);

    /* the curve buffers stay NULL until a C(E) curve has been computed */
    if (cap->ce_curve.p) {
        gsl_matrix_free(cap->ce_curve.p);
    }
    if (cap->ce_curve.E) {
        gsl_vector_free(cap->ce_curve.E);
    }
    if (cap->ce_curve.C) {
        gsl_vector_free(cap->ce_curve.C);
    }

    free(cap);
}

/* Calculate, for output index k:
 *     sum_j { p_j * Q_{k|j} }
 * where j runs over all numIn inputs, p is cap->p and Q_{k|j} is the
 * element of cap->Q at row j, column k. */
inline static double
sum_p_Q (const blahut_cap * cap, int k)
{
    double q_k = 0;
    int in = 0;

    while (in < cap->numIn) {
        q_k += gsl_vector_get(cap->p, in) * gsl_matrix_get(cap->Q, in, k);
        in++;
    }
#ifdef DEBUG_Calc
    fprintf(stdout, "(I) q_Y[%d] = %g\n", k, q_k);
#endif
    return q_k;
}

/* Calculate the first term inside exp(...) of the c_j expression:
 *     sum_k { Q_{k|j} * ln( Q_{k|j} / sum_p_Q(cap, k) ) }
 * for input index j, with k running over all numOut outputs. */
inline static double 
sum_Q_log (const blahut_cap * cap, int j)
{
    double total = 0;
    int out;

    for (out = 0; out < cap->numOut; out++) {
        const double Q_out = gsl_matrix_get(cap->Q, j, out);
        double term = 0;

        /* convention: 0 * log(0) = 0, so zero entries contribute nothing */
        if (Q_out != 0) {
            term = Q_out * my_loge(Q_out / sum_p_Q(cap, out));
        }
        total += term;
    }

#ifdef DEBUG_Calc
    fprintf(stdout, "sum_k Q_k|%d * log(Q_k|%d/q_Y[k]) = %g\n",
	    j,j,total);
#endif
    return total;
}

/* Calculate c_j for every input j and store it in cap->c:
 *     c_j = exp( sum_Q_log(cap, j) - s * e_j )
 * Returns cap. */
static blahut_cap * calc_c_j ( blahut_cap * cap )
{
    int in = 0;

    while (in < cap->numIn) {
        const double log_term = sum_Q_log(cap, in);
        const double cost_term = cap->s * gsl_vector_get(cap->e, in);

        gsl_vector_set(cap->c, in, exp(log_term - cost_term));
#ifdef DEBUG_Calc
        fprintf(stdout, "(I) c[%d] = %g\n", in, gsl_vector_get(cap->c,in));
#endif
        in++;
    }
    return cap;
}

/* Compute I_L = ln( sum_j p_j * c_j ), store it in cap->I_L and return it. */
inline static double calc_I_L (blahut_cap * cap)
{
    double weighted = 0;
    int in = 0;

    while (in < cap->numIn) {
        weighted += gsl_vector_get(cap->p,in) * gsl_vector_get(cap->c,in);
        in++;
    }
    cap->I_L = my_loge(weighted);
    return cap->I_L;
}

/* Compute I_U = ln( max_j c_j ), store it in cap->I_U and return it. */
inline static double calc_I_U (blahut_cap * cap)
{
    const double largest_c = gsl_vector_max ( cap->c );

    cap->I_U = my_loge(largest_c);
    return cap->I_U;
}

/* Update the input distribution in place:
 *     p_j <- p_j * c_j / sum_i { p_i * c_i }
 * Returns cap. */
static blahut_cap * update_p (blahut_cap * cap)
{
    double norm = 0;
    int in;

    /* normalizing constant, computed from the old p */
    for (in = 0; in < cap->numIn; in++) {
        norm += gsl_vector_get(cap->p,in) * gsl_vector_get(cap->c,in);
    }

    for (in = 0; in < cap->numIn; in++) {
        const double old_p = gsl_vector_get(cap->p,in);
        const double scale = gsl_vector_get(cap->c,in) / norm;

        gsl_vector_set(cap->p, in, old_p * scale);
    }

    return cap;
}

/* Compute E = sum_j p_j * e_j, store it in cap->E and return it. */
static double calc_E(blahut_cap * cap)
{
    double total = 0;
    int in = 0;

    while (in < cap->numIn) {
        total += gsl_vector_get(cap->p,in) * gsl_vector_get(cap->e,in);
        in++;
    }
    cap->E = total;
    return total;
}

/* Compute C = s * E + I_L from the current fields of cap, store it in
 * cap->C and return it.  The value is kept as is when cap->unit is NATS
 * and passed through my_log2(exp(.)) when it is BITS.  Any other unit
 * triggers a warning on stderr and is replaced by BITS. */
static double calc_C(blahut_cap * cap)
{
    /* normalize an unknown unit first, so a single pass is enough */
    if (cap->unit != NATS && cap->unit != BITS) {
        fprintf(stderr, "(W) Wrong information unit specified:%d, "
                "using the default: %d\n", cap->unit, BITS);
        cap->unit = BITS;
    }

    cap->C = cap->s * cap->E + cap->I_L;

    if (cap->unit == BITS) {
        cap->C = my_log2(exp(cap->C));
    }
    return cap->C;
}
    
/*
 * Run the iteration for the current value of cap->s, starting from the
 * current cap->p.  Each pass recomputes c, then the bounds I_U and I_L:
 *   - if I_U - I_L < epsilon, E and C are computed and the function returns;
 *   - if either bound is NaN, the function returns without touching E and C;
 *   - otherwise p is updated and the next pass starts.
 * When maxNumIt passes complete without either exit, E and C are computed
 * from the last p and cap->exceedsMaxNumIt is set to 1.
 * cap->it holds the number of completed passes.  Returns cap.
 */
blahut_cap *
blahut_cap_calc( blahut_cap * cap )
{
    cap->it = 0;
    while (cap->it < cap->maxNumIt) {
        calc_c_j(cap);

        const double upper = calc_I_U(cap);
        const double lower = calc_I_L(cap);

        if (upper - lower < cap->epsilon) {
            /* converged */
            calc_E(cap);
            calc_C(cap);
            return cap;
        }
        if (isnan(upper) || isnan(lower)) {
#ifdef DEBUG_PRINT_WARNING
            fprintf(stdout, "(I) I_U or I_L is NaN, terminate loop.\n");
#endif
            return cap;
        }

        update_p(cap);
        cap->it++;
    }

    /* reaching this point means the loop ran out of iterations */
    calc_E(cap);
    calc_C(cap);
    cap->exceedsMaxNumIt = 1;

    return cap;
}

/*
 * Reset the input distribution cap->p to uniform (every element set to
 * 1 / cap->p->size).  Returns cap.
 *
 * Deliberately not declared `inline`: under C99/C11 rules a plain `inline`
 * definition of an externally visible function need not produce an external
 * definition, which can leave callers with an unresolved symbol.
 */
blahut_cap * 
blahut_cap_set_p_uniform( blahut_cap * cap )
{
    gsl_vector_set_all(cap->p, 1.0/cap->p->size);
    return cap;
}

/*
 * Size and allocate the storage in cap->ce_curve:
 *   len = floor((s_U - s_L) / s_d) samples,
 *   p   = len x numIn matrix, E and C = vectors of length len,
 * all zero-initialized.  Returns a pointer to cap->ce_curve.
 *
 * NOTE(review): buffers left from an earlier call are overwritten without
 * being freed -- confirm whether callers rely on earlier curves staying
 * valid before changing this.
 */
static blahut_ce_curve * 
ce_curve_init( blahut_cap * cap )
{
    blahut_ce_curve * curve = &(cap->ce_curve);
    const double span = cap->s_U - cap->s_L;

    curve->len = (unsigned int)floor(span/cap->s_d);
    curve->p = gsl_matrix_calloc(curve->len, cap->numIn);
    curve->E = gsl_vector_calloc(curve->len);
    curve->C = gsl_vector_calloc(curve->len);

    return curve;
}

/*
 * Compute the C(E) curve by stepping s from cap->s_L in increments of
 * cap->s_d, one sample per step (cap->ce_curve.len samples in total).
 * For every s the input distribution is reset to uniform and
 * blahut_cap_calc() is run; the resulting p, E and C are stored in row /
 * element i of cap->ce_curve.
 *
 * The sweep stops early the first time C == 0 or E == 0; samples from that
 * index on keep the zero values they were allocated with.
 *
 *   filename  when not NULL, all len samples (including any zero-valued
 *             trailing ones) are written to this file as text, one line per
 *             sample: "E C p_0 p_1 ... ".  Failure to open the file prints
 *             an error and exits with status 1.  A failure while writing or
 *             closing is reported on stderr; the curve is still returned.
 *
 * Returns a copy of cap->ce_curve; the buffers it points to remain owned by
 * cap.
 */
blahut_ce_curve  
blahut_cap_iterate_over_s( blahut_cap * cap, const char* filename)
{
    /* Initialize the field 'cap->ce_curve' for storing data */
    ce_curve_init(cap);

#ifdef DEBUG
    fprintf(stdout, "(I) Calculating C(E) curve ...\n");
#endif

    unsigned int NumS = cap->ce_curve.len; /* # samples on the curve */
    unsigned int i,j;
    double step = cap->s_d;
    /* Begin iterating over s */
    for (i=0,cap->s = cap->s_L; i<NumS; i++, cap->s+=step) {
        blahut_cap_set_p_uniform(cap);
        blahut_cap_calc(cap);

        /* stop iterating */
        if (cap->C == 0 || cap->E == 0) {
#ifdef DEBUG_PRINT_WARNING
            fprintf(stdout, "(W) C = 0 or E = 0 encountered, break the iteration.\n");
#endif
            break;
        }

        /* Store the result */
        {
            gsl_vector_view p_view = gsl_matrix_row(cap->ce_curve.p, i);
            gsl_vector_memcpy(&p_view.vector, cap->p);

            gsl_vector_set(cap->ce_curve.E, i, cap->E);
            gsl_vector_set(cap->ce_curve.C, i, cap->C);
        }
    }
#ifdef DEBUG
    fprintf(stdout, "(I) Finished.\n");
#endif

    /* Write to file */
    if (filename != NULL) {
#ifdef DEBUG
        fprintf(stdout, "(I) Writing data to file ...\n");
#endif
        FILE * file = fopen(filename,"w");
        if (file == NULL) {
            fprintf(stderr, "(E) Error opening file \"%s\" for writing.\n", 
                    filename);
            exit(1);
        }
        for (i=0; i<NumS; i++) {
            fprintf(file, "%g ", gsl_vector_get(cap->ce_curve.E, i));
            fprintf(file, "%g ", gsl_vector_get(cap->ce_curve.C, i));
            for (j = 0; j < cap->numIn; j++) {
                fprintf(file, "%g ", gsl_matrix_get(cap->ce_curve.p, i, j));
            }
            fprintf(file, "\n");
        }

        /* A failed write may only surface when the stream is flushed, so
         * both the error flag and the result of fclose() are checked. */
        int write_failed = ferror(file);
        if (fclose(file) != 0) {
            write_failed = 1;
        }
        if (write_failed) {
            fprintf(stderr, "(E) Error writing data to file \"%s\".\n",
                    filename);
        } else {
#ifdef DEBUG
            fprintf(stdout, "(I) Finished.\n");
#endif
        }
    }

    return cap->ce_curve;
}

/*
 * Set the range of s used when computing the C(E) curve.
 *
 *   s_L   lower limit of s
 *   s_U   upper limit of s; must not be less than s_L
 *   step  increment of s between samples; must be greater than zero,
 *         because the number of samples is floor((s_U - s_L) / step)
 *
 * On an invalid range or step an error is printed to stderr and the
 * process exits with status 3.  Returns cap.
 */
blahut_cap * 
blahut_cap_setSRange(blahut_cap * cap, double s_L, double s_U, double step)
{
    if (s_U < s_L) {
        fprintf(stderr, "(E) blahut_cap_setSRange:" 
                "The upper limit %g is less than the lower limit %g.\n",
                s_U, s_L);
        exit(3);
    }
    /* written as !(step > 0) so that a NaN step is rejected as well */
    if (!(step > 0)) {
        fprintf(stderr, "(E) blahut_cap_setSRange: "
                "The step %g should be positive.\n", step);
        exit(3);
    }
    cap->s_L = s_L;
    cap->s_U = s_U;
    cap->s_d = step;

    return cap;
}



/* Kefei Lu 2008-05-15 */