%% calcBounds
%
% Calculates lower and upper bounds, zL and zU, on counterfactual (CF)
% earnings, given optimal CF earnings z, elasticity e, adjustment cost phi,
% a structure tax describing the tax system, and solver options opt.

function [zL, zU] = calcBounds(z, e, phi, tax, opt)
% CALCBOUNDS  Lower/upper bounds on counterfactual earnings.
%
%   [zL, zU] = calcBounds(z, e, phi, tax, opt)
%
%   For each element of z, solves for zL (searched from max(z-phi,0), and
%   if needed on [0, z]) and zU (searched above z) at which
%   calcUDiff(uStar, ., n, e, phi, tax) is zero, with uStar = uOpt(z, e, tax).
%
%   Inputs:
%     z   - optimal counterfactual earnings
%     e   - elasticity
%     phi - adjustment cost; unless phi > 0, zL = zU = z
%     tax - struct describing the tax system (tax.t0 is read here)
%     opt - solver options, passed to fsolve and to fzero
%
%   Outputs:
%     zL, zU - lower and upper bounds. Both are NaN(size(z)) when
%              calcUDiff is NaN at the initial lower-bound guess.
%
%   Raises an assertion error if some point cannot be solved.

    if ~(phi > 0) || isempty(z)
        % No adjustment cost (or nothing to solve): bounds collapse to z.
        zL = z;
        zU = z;
        return
    end

    uStar = uOpt(z, e, tax);
    n = z./(1-tax.t0).^e;
    zLStart = max(z-phi, 0);

    if any(isnan(reshape(calcUDiff(uStar, zLStart, n, e, phi, tax), [], 1)))
        fprintf('Trouble for e, phi: %6.5f, %7.4f\n', e, phi);
        % Return NaN bounds rather than leaving the outputs unassigned
        % (which made any caller requesting outputs error out).
        zL = NaN(size(z));
        zU = NaN(size(z));
        return
    end

    % ---------------- Lower bound ----------------
    % Try to solve all points jointly.
    try
        [zL, val, flag] = fsolve(@(zLower) ...
            calcUDiff(uStar, zLower, n, e, phi, tax), zLStart, opt);
    catch
        % Keep zL shaped like z so the point-by-point fallback writes into
        % a correctly-sized array (otherwise zL grows as a row vector).
        zL = zLStart;
        val = ones(size(z));
        flag = 0;
    end

    % If the joint solve failed, first solve point-by-point.
    if ~(flag == 1 || max(abs(val(:))) < 1e-10)
        valV = val;
        flagV = zeros(numel(valV), 1);
        for zz = 1:numel(n)
            % NOTE(review): opt is also given to fsolve; confirm fzero
            % accepts the same options object.
            [zL(zz), valV(zz), flagV(zz)] = fzero(@(zLower) ...
                calcUDiff(uStar(zz), zLower, n(zz), e, phi, tax), zLStart(zz), opt);
        end

        % For points still unsolved, zero is the lower bound if the
        % utility difference there is non-positive.
        troublePts = find(flagV ~= 1);
        diffAtZero = calcUDiff(uStar(troublePts), 0, n(troublePts), e, phi, tax);
        zeroIsBound = diffAtZero <= 0;
        zL(troublePts(zeroIsBound)) = 0;
        flagV(troublePts(zeroIsBound)) = 1;
        troublePts(zeroIsBound) = [];

        % Bracketed search in [0, z] for the remaining points. The row
        % reshape matters: a for-loop iterates over columns.
        for ii = reshape(troublePts, 1, [])
            try
                [zL(ii), valV(ii), flagV(ii)] = fzero(@(zLower) ...
                    calcUDiff(uStar(ii), zLower, n(ii), e, phi, tax), [0 z(ii)]);
            catch
                fprintf('Bracketed zL search failed: ii=%d, zL=%g, phi=%g, e=%g\n', ...
                    ii, real(zL(ii)), phi, e);
            end
        end

        % Report failing points if any flag is bad OR zL went complex.
        if ~(all(flagV == 1) && isreal(zL))
            bad = flagV ~= 1;
            disp([reshape(z(bad), [], 1), flagV(bad), reshape(valV(bad), [], 1)]);
            fprintf('e = %g, phi = %g\n', e, phi);
        end
        assert(all(flagV == 1), ...
            'calcBounds: failed to solve for zL (e=%g, phi=%g)', e, phi);
    end

    % ---------------- Upper bound ----------------
    % Start above z, twice as far from z as zL is below it.
    zUStart = z + 2.*(z - zL);
    try
        [zU, val, flag] = fsolve(@(zUpper) ...
            calcUDiff(uStar, zUpper, n, e, phi, tax), zUStart, opt);
    catch
        % Mirror the lower-bound handling: fall through to per-point solves.
        zU = zUStart;
        val = ones(size(z));
        flag = 0;
    end

    % If the joint solve failed, first solve point-by-point.
    if ~(flag == 1 || max(abs(val(:))) < 1e-10)
        valV = val;
        flagV = zeros(numel(valV), 1);
        for zz = 1:numel(n)
            [zU(zz), valV(zz), flagV(zz)] = fzero(@(zUpper) ...
                calcUDiff(uStar(zz), zUpper, n(zz), e, phi, tax), zUStart(zz), opt);
        end

        % Bracketed search on [zUStart, 1e8] for points still unsolved
        % (1e8 is a hard-coded upper search limit).
        for zz = reshape(find(flagV ~= 1), 1, [])
            [zU(zz), valV(zz), flagV(zz)] = fzero(@(zUpper) ...
                calcUDiff(uStar(zz), zUpper, n(zz), e, phi, tax), [zUStart(zz) 1e8], opt);
            if flagV(zz) ~= 1
                fprintf('zU search failed: flag=%d, zU=%g, e=%g, phi=%g\n', ...
                    flagV(zz), real(zU(zz)), e, phi);
            end
        end
        assert(all(flagV == 1), ...
            'calcBounds: failed to solve for zU (e=%g, phi=%g)', e, phi);
    end
end