% This code implements the Bayesian learning algorithm for stent usage
% written by Matt Grennan and Bob Town for the "regulating innovation"
% project.

% Matt Grennan / grennan@wharton.upenn.edu

clear all
% Move into the project's code directory. Two candidate locations are
% checked; the 'dir' option makes exist() test for a folder only, so a file
% of the same name cannot select a path that cd() would then reject.
if exist('~/IntlMedDevMkts/code','dir')
    cd('~/IntlMedDevMkts/code')
else
    cd('~/Dropbox/IntlMedDevMkts/code')
end
%addpath('csvwrite_with_headers')

% Run settings (echoed to the console on purpose: no trailing semicolons).
% bootstrap is passed to ImportData_fcn and appended to every output filename.
bootstrap = 0 % 0 for full sample; change for bootstrap samples
% QALYs and DollarsPerQALY are combined into the dollar scaling factor
% `scale` in the scaling step below.
QALYs = .05
DollarsPerQALY = 100000


%% --------------------------------------------------------------------------
%--------------------------------------------------------------------------
% Import data
% Reads the product-level ("agg") and hospital-level ("h") CSV extracts with
% importdata and unpacks their numeric .data matrices into workspace vectors
% via ImportData_fcn. The bootstrap setting is passed through to that function.

%sort tj
%outsheet t j tj age clinical USlater USeventual profit_lifetime pay DES tg zeros Vjt sj p_w Mt p_w_m1 sjg_m1 Jg using EUdata_H_agg.csv, replace
% NOTE(review): the commented lines above look like the Stata export that
% produced this file, but they name EUdata_H_agg.csv while the code below reads
% EUdata_agg.csv -- confirm which export is current.
D_agg = importdata('../data/EUdata_agg.csv');
Data_agg = D_agg.data;
%sort tj h
%outsheet t j h tj th Q w_ht s0 sght thg using EUdata_Nht.csv, replace 
% NOTE(review): likewise, this comment names EUdata_Nht.csv but EUdata_h.csv
% is read -- confirm.
D_h = importdata('../data/EUdata_h.csv');
Data_h = D_h.data;

% Outputs are positional: this list must match ImportData_fcn's return list
% exactly, in the same order.
[  t, j, j_true , age, clinical, amc, sj, s0, p, pm1 , ageIV, dum_prod, old_j, USeventual, USe, USe_jt, ...
    sjg , Jg , dum_g , g, tj,th,thg,frac_zeros_jt,Nht,Ht,dum_ght,qm1, ...
    h,th_ght, eta_jht, etaQW_jht, Mht, Mt , p_jht , t_ht  , Mht_ht , t_jht, t_ght, pay, pct_life_profit ] ...
    = ImportData_fcn( Data_agg , Data_h , bootstrap );

%variables for convenience
% Log share ratio ln(sj/s0) and log within-group share ln(sjg); reallog
% errors rather than returning complex values if an argument is negative.
lnsjs0 = reallog(sj./s0);
lnsjg = reallog(sjg);
% Within-group log share interacted with each column of dum_g, and its
% instruments built from reallog(Jg) and reallog(Jg).^2 interacted the same way.
Lnsjg = dum_g.*repmat(lnsjg,1,size(dum_g,2));
LnsjgIV = [ dum_g.*repmat(reallog(Jg),1,size(dum_g,2)) , dum_g.*repmat(reallog(Jg).^2,1,size(dum_g,2)) ]  ;
Xe = [ dum_prod ];
%X = [ Lnsjg , -p , Xe ];
%Z = [ LnsjgIV , -pm1 , Xe ];
%sigmaQ2_inv_start = [ 1./var(grpstats(lnsjs0(USeventual==1),j(USeventual==1),'mean')) ; 1./var(grpstats(lnsjs0(USeventual==0),j(USeventual==0),'mean')) ];
% Starting values for the two sigmaQ2_inv entries: 1/0.3^2 each.
sigmaQ2_inv_start = 1./([.3;.3]).^2;

% Marginal costs mc_jt / mc_jht come from fcn_Estimate_Supply in the Pricing
% section. A placeholder computed here (0.9 * min p_jht within each j(tj)
% group) was overwritten there before ever being read, so it has been removed.
% 
J = ones(size(age,1),1);
J_j = ones(max(j),1);
%
% Number of distinct values of each index, passed to the estimation routines.
num_uniq_thg=numel(unique(thg));
num_uniq_th_ght=numel(unique(th_ght));
num_uniq_tj=numel(unique(tj));
num_uniq_t_ht=numel(unique(t_ht));


%% --------------------------------------------------------------------------
%--------------------------------------------------------------------------
% Plot data




%% --------------------------------------------------------------------------
%--------------------------------------------------------------------------
% Estimate parameters -- Demand and Learning
% aT is passed to fcn_Estimate_Demand and also sets how many age-group means
% of pct_life_profit are kept.
aT=36;
% Replace pct_life_profit with its mean within each age group (grpstats
% orders groups by ascending age) and keep the first aT groups.
% NOTE(review): pct_life_profit is not referenced again in this script after
% the next line -- confirm whether it is still needed.
pct_life_profit = grpstats(pct_life_profit,age,'mean');
pct_life_profit = pct_life_profit(1:aT);

% Loop through Logit, NL, QW, QW+learn
%par_L = size(SigmaEU_start,1) + size(SigmaA_start,1);
%par_NL = par_L + g;
%par_QW = par_NL + g;
%par_QWlearn = par_QW + 1;
%LOOP = [ par_L ; par_NL ; par_QW ; par_QWlearn ]; % size of param vector for Logit, NL, QW, QW+learn
%PtEst = zeros(14,4);
%for l = 1:4
    
% Demand/learning estimation. Arguments are positional and must match
% fcn_Estimate_Demand's signature. Later code reads column 5 of SigmaS, QjFES
% and age_uS as the working specification, and columns 2,3,4,6,7 of SigmaS
% when saving; columns presumably index alternative specifications -- confirm
% in fcn_Estimate_Demand.
[SigmaS,QjFES,age_uS] ...
    = fcn_Estimate_Demand(  bootstrap, age, clinical, amc, lnsjs0, lnsjg , p, Xe , Lnsjg , USe_jt, USe, ... 
    qm1, pm1 , LnsjgIV, ageIV, ...
    dum_g, g , num_uniq_thg, num_uniq_th_ght, num_uniq_tj,  ...
    Mht,Mt,h,j, tj,th,th_ght,thg,frac_zeros_jt,eta_jht,etaQW_jht,Nht,Ht,dum_ght,aT,...
    sigmaQ2_inv_start );

%[SigmaS,QjFES,age_uS] ...
%    = fcn_Estimate_Demand_Extensions(  age, clinical, amc, lnsjs0, lnsjg , p, Xe , Lnsjg , USe_jt, USe, ... 
%    qm1, pm1 , LnsjgIV, ageIV, ...
%    dum_g, g , num_uniq_thg, num_uniq_th_ght, num_uniq_tj,  ...
%    Mht,Mt,h,j, tj,th,th_ght,thg,frac_zeros_jt,eta_jht,etaQW_jht,Nht,Ht,dum_ght,aT,...
%    sigmaQ2_inv_start );

% Echo the estimates to the console.
SigmaS

% Restrict to observations with age <= aTp (aTp is also used by the xi plot below).
aTp=24;
ageT=age(age<=aTp);
USeT=USeventual(age<=aTp);
age_uST=age_uS(age<=aTp,:);
%
% For columns 3 and 4 of age_uS, plot the mean value at each age, separately
% for USeventual==1 (figure 1) and USeventual==0 (figure 2), and save both as
% PDF. The group means are kept in age_u1 / age_u0 for the CSV written at the end.
plotCols = [3,4];
isUS1 = (USeT==1);
isUS0 = (USeT==0);
% x-axis values (one per distinct age) do not depend on the column, so
% compute them once per group.
ageAxis1 = grpstats(ageT(isUS1),ageT(isUS1),'mean');
ageAxis0 = grpstats(ageT(isUS0),ageT(isUS0),'mean');
age_u1 = zeros(numel(ageAxis1),numel(plotCols));
age_u0 = zeros(numel(ageAxis0),numel(plotCols));
%
figure(1)
clf   % clear any lines left from a previous run before 'hold on'
hold on
for spec=1:numel(plotCols)
    age_u1(:,spec) = grpstats(age_uST(isUS1,plotCols(spec)),ageT(isUS1),'mean');
    plot( ageAxis1 , age_u1(:,spec) )
end
hold off
saveas(1,'../output/aFEvNN_USe1.pdf','pdf')
%
figure(2)
clf   % clear any lines left from a previous run before 'hold on'
hold on
for spec=1:numel(plotCols)
    age_u0(:,spec) = grpstats(age_uST(isUS0,plotCols(spec)),ageT(isUS0),'mean');
    plot( ageAxis0 , age_u0(:,spec) )
end
hold off
saveas(2,'../output/aFEvNN_USe0.pdf','pdf')

% model params to use for subsequent analyses
% Column 5 of SigmaS / QjFES / age_uS is carried forward. Row 9 of SigmaS is
% taken out as rho, so Sigma stacks rows 1-8 and row 10.
Sigma=[SigmaS(1:8,5);SigmaS(10,5)]
rho = SigmaS(9,5);
sigmaQ2_inv=1./(SigmaS(end-3:end-2,5)).^2
QjFE=QjFES(:,5);
age_u=age_uS(:,5);
%
% Unpack Sigma as [thetap; Lambda (nG); sigmaH2 (nG); sigmaEU2_inv;
% sigmaAc2_inv; sigmaA2_inv; gammaH], which needs 2*nG+5 entries. The fixed
% rows selected above only satisfy this when nG==2, so stop with a clear
% error instead of silently reading the wrong entries.
nG = size(dum_g,2);
assert(numel(Sigma) == 2*nG+5, ...
    'Sigma has %d entries but size(dum_g,2)=%d requires %d', numel(Sigma), nG, 2*nG+5);
thetap = Sigma(1);
Lambda = Sigma(2:nG+1,:);
sigmaH2 = Sigma(nG+2:2*nG+1,:);
sigmaEU2_inv = Sigma(2*nG+2);
sigmaAc2_inv  = Sigma(2*nG+3);
sigmaA2_inv  = Sigma(2*nG+4);
gammaH = Sigma(2*nG+5);
%sigmaA2q_inv  = Sigma(end-1)/100;
%mu = Sigma(end);
%if sigmaA2q_inv > 0
%    sigmaA2_inv  = sigmaA2_inv + qm1*sigmaA2q_inv;
%end
%
% Map group-level parameters to observations through the group dummies.
lambda=dum_g*Lambda;
lambda_ght=dum_ght*Lambda;
sigmaH2_jt = dum_g*sigmaH2;
R = sigmaH2_jt./ (2*(1-lambda).^2);
% One over the sum of the precision terms (the *_inv parameters).
sigma2_jt = 1./( USe_jt*sigmaQ2_inv + sigmaEU2_inv + amc.*sigmaA2_inv + clinical.*sigmaAc2_inv );
w_signals_jt = w_signals_fcn(clinical,amc,sigmaEU2_inv,sigmaA2_inv,sigmaAc2_inv,gammaH) .* sigma2_jt;
% NOTE(review): w_signals_jt already includes a sigma2_jt factor and is
% multiplied by sigma2_jt again below -- confirm this is intended.
d_jt = lnsjs0 - lnsjg.*lambda - (1-lambda).*R - .5*(w_signals_jt.*sigma2_jt)./(1-lambda);
% Residual after removing the price, rho-weighted variance and product fixed-effect terms.
xi_jt = d_jt + thetap*p + rho*.5*sigma2_jt - dum_prod*QjFE;

% Diagnostic (not saved): mean of xi_jt.^2 by age for USeventual==1 and ==0,
% ages <= aTp.
xiT=xi_jt(age<=aTp);
figure
hold on
plot(grpstats(ageT(USeT==1),ageT(USeT==1),'mean'),grpstats(xiT(USeT==1).^2,ageT(USeT==1),'mean'))
plot(grpstats(ageT(USeT==0),ageT(USeT==0),'mean'),grpstats(xiT(USeT==0).^2,ageT(USeT==0),'mean'))
hold off

% scaling using QALY studies?
% Dollar-per-util factor: QALYs*DollarsPerQALY divided by the median of ATTj
% over products with USe==2, where ATTj = log(1+exp(Q)) / (exp(Q)/(1+exp(Q))).
% Dividing by exp(Q)/(1+exp(Q)) is the same as multiplying by 1+exp(-Q).
% log1p keeps precision when exp(QjFE) is tiny: reallog(1+exp(Q)) rounds to
% 0 for Q below about -37. This form also avoids Inf/Inf = NaN for very large Q.
ATTj = log1p(exp(QjFE)) .* (1 + exp(-QjFE));
isUSe2 = (USe==2);
if ~any(isUSe2)
    % median of an empty set is NaN, which would propagate into every
    % dollar-valued result below; make that visible.
    warning('No products with USe==2: scale will be NaN');
end
scale = QALYs*DollarsPerQALY/ median(ATTj(isUSe2));


%% --------------------------------------------------------------------------
%--------------------------------------------------------------------------
% Estimate parameters -- Pricing
% Supply-side estimation given the demand parameters (Sigma, sigmaQ2_inv) and
% the dollar factor `scale`. Arguments and outputs are positional and must
% match fcn_Estimate_Supply's signature. This call is where mc_jt is set for
% all later code.

[pB_jt,mc_jt,adjAV_jt,Ep_jt] ...
    = fcn_Estimate_Supply( dum_prod,  J,J_j, Sigma , dum_g , lnsjs0 , lnsjg , USe_jt , sigmaQ2_inv , amc , clinical , ...
        tj , etaQW_jht , eta_jht , thg , dum_ght , th_ght , th , Mht , p_jht , p , t_ht , Mht_ht , Mt , ...
        num_uniq_thg, num_uniq_th_ght, num_uniq_tj, num_uniq_t_ht, ...
        scale , j );

% Expand mc_jt to the rows indexed by tj (used by the counterfactual below).
mc_jht = mc_jt(tj,:);


%% --------------------------------------------------------------------------
%--------------------------------------------------------------------------
% Partial Eq value of Risk
% Runs the partial-equilibrium counterfactual three times, shifting every
% product fixed effect by muQ_mod = -1.28, 0, +1.28 (1.28 is presumably the
% standard-normal 90th-percentile z value -- confirm). Row i of Risk_qw,
% PctStent, CSps, PSps, TSps and EPDps holds run i. checks and Risk are
% overwritten on each pass, so only the last run (muQ_mod = +1.28) survives.
for i=1:3
    muQ_mod = (i-2)*1.28;
% Positional arguments/outputs: must match fcn_Counterfactual_PartialEqRisk.
[checks,Risk,Risk_qw(i,:),PctStent(i,:),CSps(i,:),PSps(i,:),TSps(i,:),EPDps(i,:)] = ...
    fcn_Counterfactual_PartialEqRisk( Sigma, Mt,Mht,Mht_ht,QjFE+muQ_mod,xi_jt,age,j,USe_jt,sigmaQ2_inv,dum_prod,rho,tj,etaQW_jht,eta_jht,...
    t, t_jht, t_ght, thg, th, th_ght, num_uniq_thg, num_uniq_th_ght, num_uniq_tj, num_uniq_t_ht, ...
    mc_jht, scale, pB_jt , p, dum_g , lnsjs0 , lnsjg , dum_ght , t_ht , mc_jt , USe , pay );
end


%% --------------------------------------------------------------------------
%--------------------------------------------------------------------------
% Save results
% Each file gets two leading columns: the bootstrap id and a row counter.
% dlmwrite with an explicit precision is used because csvwrite writes at most
% five significant digits.

% Columns of SigmaS / QjFES written to the "all specifications" file. The
% name saveCols avoids shadowing MATLAB's built-in append function.
saveCols=[2,3,4,6,7];
% NOTE(review): this file inserts rho*thetap after SigmaS row 5, while
% Parameters_D below inserts it after row 8 -- confirm the intended layouts.
Parameters_Dall = [ SigmaS(1:5,saveCols) ; SigmaS(9,saveCols).*SigmaS(1,saveCols) ; SigmaS(6:11,saveCols) ; mean(QjFES(:,saveCols),1) ; SigmaS(12:end,saveCols)  ];
Parameters_Dall = [ bootstrap*ones(size(Parameters_Dall,1),1) , (1:size(Parameters_Dall,1))' , Parameters_Dall ]
filename = sprintf('../temp/Parameters_Dall_%d.csv',bootstrap);
dlmwrite(filename,Parameters_Dall,'delimiter',',','precision','%.15g')

% Working specification (column 5): Sigma, rho*thetap, 1./sqrt(sigmaQ2_inv),
% the dollar scale, and the last two rows of SigmaS(:,5).
Parameters_D = [ Sigma(1:8) ; rho.*thetap ; Sigma(9) ; 1./sigmaQ2_inv.^.5 ; scale ; SigmaS(end-1:end,5) ];
Parameters_D = [ bootstrap*ones(size(Parameters_D,1),1) , (1:size(Parameters_D,1))' , Parameters_D ]
filename = sprintf('../temp/Parameters_D_%d.csv',bootstrap);
dlmwrite(filename,Parameters_D,'delimiter',',','precision','%.15g')

% Qj1 estimates after EU trials
% Qj1(j) = QjFE(j) + xi_jt at product j's first observed (minimum) age.
agemin=grpstats(age,j,'min');
isFirstAge = (age==agemin(j));
% Put each first-age residual into row j explicitly, so Qj1 lines up with
% QjFE whatever the row order of the data. Plain logical indexing returns
% residuals in row order, which matches QjFE only if rows are sorted by j.
nFirstAge = accumarray(j(isFirstAge), 1, size(QjFE));
assert(all(nFirstAge==1), 'Expected exactly one first-age observation per product');
Qj1 = QjFE + accumarray(j(isFirstAge), xi_jt(isFirstAge), size(QjFE));
j_j=grpstats(j_true,j,'min');
Parameters_QjFE = [ bootstrap*ones(size(QjFE,1),1) , j_j , QjFE , Qj1 ];
mean(QjFE)
std(QjFE)
filename = sprintf('../temp/Parameters_QjFE_%d.csv',bootstrap);
% dlmwrite with explicit precision: csvwrite keeps only five significant digits.
dlmwrite(filename,Parameters_QjFE,'delimiter',',','precision','%.15g')

%ageU = grpstats(age_u(age<=aT),age(age<=aT),'mean');
% Age-profile means from the plotting step (USeventual==1 columns, then ==0).
Parameters_ageU = [ bootstrap*ones(size(age_u1,1),1) , (1:size(age_u1,1))' , age_u1 , age_u0 ]
filename = sprintf('../temp/Parameters_ageU_%d.csv',bootstrap);
dlmwrite(filename,Parameters_ageU,'delimiter',',','precision','%.15g')

% Counterfactual results. checks and Risk come from the last loop iteration;
% the *_qw / PctStent / TSps / EPDps matrices hold one row per run.
% dlmwrite with explicit precision is used because csvwrite keeps only five
% significant digits.
Cfcl_PEqRisk = [ checks-1 ; Risk ; Risk_qw ; PctStent ; TSps ; EPDps];
Cfcl_PEqRisk = [ bootstrap*ones(size(Cfcl_PEqRisk,1),1) , (1:size(Cfcl_PEqRisk,1))' , Cfcl_PEqRisk ]
filename = sprintf('../temp/Cfcl_PEqRisk_%d.csv',bootstrap);
dlmwrite(filename,Cfcl_PEqRisk,'delimiter',',','precision','%.15g')

Parameters_S = [ bootstrap*ones(size(pB_jt,1),1) , (1:size(pB_jt,1))' , pB_jt , mc_jt ];
mean(pB_jt)
std(pB_jt)
filename = sprintf('../temp/Parameters_S_%d.csv',bootstrap);
dlmwrite(filename,Parameters_S,'delimiter',',','precision','%.15g')

adjAVtotal = adjAV_jt + p - mc_jt;
% Summary by the first group dummy: row 1 is dum_g(:,1)==0, row 2 is
% dum_g(:,1)==1. Columns are mean/std pairs of Ep_jt, adjAVtotal, mc_jt, pB_jt.
inG0 = (dum_g(:,1)==0);
inG1 = (dum_g(:,1)==1);
Parameters_S_sum = [ bootstrap*ones(2,1) , (1:2)' , ...
    [mean(Ep_jt(inG0));mean(Ep_jt(inG1))] , [std(Ep_jt(inG0));std(Ep_jt(inG1))] , ...
    [mean(adjAVtotal(inG0));mean(adjAVtotal(inG1))] , [std(adjAVtotal(inG0));std(adjAVtotal(inG1))] , ...
    [mean(mc_jt(inG0));mean(mc_jt(inG1))] , [std(mc_jt(inG0));std(mc_jt(inG1))] , ...
    [mean(pB_jt(inG0));mean(pB_jt(inG1))] , [std(pB_jt(inG0));std(pB_jt(inG1))] ]
filename = sprintf('../temp/Parameters_S_sum_%d.csv',bootstrap);
dlmwrite(filename,Parameters_S_sum,'delimiter',',','precision','%.15g')

%
% All outputs have been written; terminate the MATLAB session
% (quit is documented as equivalent to exit).
quit


