package gov.cms.grouper.snf.component.v100.logic;

import gov.cms.grouper.snf.SnfContext;
import gov.cms.grouper.snf.SnfTables;
import gov.cms.grouper.snf.lego.Pair;
import gov.cms.grouper.snf.lego.SnfComparator;
import gov.cms.grouper.snf.lego.SnfUtils;
import gov.cms.grouper.snf.lego.SnfVersionImpl;
import gov.cms.grouper.snf.model.SnfDiagnosisCode;
import gov.cms.grouper.snf.model.reader.Rai300;
import gov.cms.grouper.snf.model.table.BasicRow;
import gov.cms.grouper.snf.model.table.NtaCmgRow;
import gov.cms.grouper.snf.model.table.NtaComorbidityRow;
import gov.cms.grouper.snf.util.ClaimInfo;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.SortedSet;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
 * <a href="doc-files/mds-3.0-rai-manual-v1.17.1_october_2019.pdf#page=686" class="req">PDPM Payment
 * Component: NTA</a>
 *
 * <p>Computes the NTA case-mix group for a claim. Comorbidity points are collected from secondary
 * diagnoses, from additional MDS items on the assessment (step 1.3) and from parenteral/IV feeding
 * (step 1.2). Those points are summed into an NTA score (step 2), and the score is mapped to a
 * case-mix group (step 3).
 */
public class NtaLogic extends SnfVersionImpl<String> {

  /**
   * MDS item sets that identify the two parenteral/IV feeding rows of the NTA comorbidity table. A
   * comorbidity row is treated as a feeding row when its MDS items are exactly equal to one of these
   * sets.
   */
  public enum ParenteralIvFeeding {
    HighIntensity(Arrays.asList(Rai300.K0510A2.name(), Rai300.K0710A2.name())),
    LowIntensity(
        Arrays.asList(Rai300.K0510A2.name(), Rai300.K0710A2.name(), Rai300.K0710B2.name()));

    private final SortedSet<String> items;

    ParenteralIvFeeding(List<String> items) {
      this.items = Collections.unmodifiableSortedSet(SnfUtils.toOrderedSet(items));
    }

    /**
     * Returns the item names of this feeding level.
     *
     * @return an unmodifiable, ordered set of MDS item names
     */
    public SortedSet<String> getItems() {
      return this.items;
    }

  }


  /**
   * Builds a predicate that accepts a comorbidity row whose MDS items are exactly equal to the given
   * item set.
   */
  public static final Function<SortedSet<String>, Predicate<NtaComorbidityRow>> FeedingFilter =
      (items) -> (row) -> items.equals(row.getMdsItems());

  /** ICD-10 code whose secondary diagnoses are excluded from the NTA category lookup. */
  public static final String ICD10_CODE_HIV = "B20";

  /** MDS item that is counted by value (> 0) rather than by being checked. */
  private static final String M0300D1 = "M0300D1";

  private final Set<String> secondaryDxNtaCategories;
  private final Set<String> assessmentNames;
  private final Supplier<Boolean> hasK0510A2;
  private final Supplier<Integer> k0710A2Value;
  private final Supplier<Integer> k0710B2Value;
  private final Predicate<NtaComorbidityRow> parenteralIvFeeding;
  private final Predicate<NtaComorbidityRow> step1$3AdditionalCommorbidities;
  private final ClaimInfo claim;

  /**
   * Prepares all inputs of the NTA calculation from the claim.
   *
   * @param claim the claim whose assessment is evaluated; also supplies the grouper version
   * @param secondaryDiagnosis secondary diagnoses used to derive NTA categories (B20 is excluded)
   */
  public NtaLogic(ClaimInfo claim, List<SnfDiagnosisCode> secondaryDiagnosis) {
    super(claim.getVersion());

    this.claim = claim;
    // Checked items count toward additional comorbidities, except M0300D1. M0300D1 is counted only
    // when its value is greater than zero (see the note on page 688 of the RAI manual).
    // NOTE(review): addAll relies on getAssessmentNames returning a mutable set - confirm.
    this.assessmentNames =
        claim.getAssessmentNames((item) -> item.isCheck() && !M0300D1.equals(item.getItem()));
    // Test the item name first so getValueInt() is only evaluated for M0300D1.
    this.assessmentNames.addAll(claim
        .getAssessmentNames((item) -> M0300D1.equals(item.getItem()) && item.getValueInt() > 0));

    // Constant-first equals: a diagnosis with a null value is kept instead of throwing an NPE.
    this.secondaryDxNtaCategories = claim.getNtaCategories(secondaryDiagnosis.stream()
        .filter(snfDiagnosisCode -> !ICD10_CODE_HIV.equals(snfDiagnosisCode.getValue()))
        .collect(Collectors.toList()));

    this.hasK0510A2 = () -> claim.isCheckedAndNotNull(Rai300.K0510A2);
    this.k0710A2Value = () -> claim.getAssessmentValue(Rai300.K0710A2);
    this.k0710B2Value = () -> claim.getAssessmentValue(Rai300.K0710B2);
    // NOTE(review): this calls a public, non-final method from the constructor. An override in a
    // subclass would run before the subclass's own fields are initialised.
    this.parenteralIvFeeding =
        step1$2ParenteralIvFeedingCondition(this.hasK0510A2, this.k0710A2Value, this.k0710B2Value);
    this.step1$3AdditionalCommorbidities =
        (row) -> step1$3AdditionalCommorbidities(this.assessmentNames, row);
  }

  /**
   * Determine whether the resident meets the criteria for the comorbidity: "Parenteral/IV Feeding –
   * High Intensity" or the comorbidity: "Parenteral/IV Feeding – Low Intensity"
   * <a href="doc-files/mds-3.0-rai-manual-v1.17.1_october_2019.pdf#page=686" class=
   * "req">step1.2</a>
   *
   * <p>The levels are selected as follows:
   * <ul>
   * <li>High intensity requires K0510A2 to be checked and K0710A2 to equal 3.</li>
   * <li>Low intensity requires K0510A2 to be checked and K0710A2 and K0710B2 to both equal 2.</li>
   * </ul>
   * A null (absent) value never qualifies.
   *
   * @param hasK0510A2 whether K0510A2 is checked
   * @param k0710A2Value value of K0710A2
   * @param k0710B2Value value of K0710B2; only read when K0710A2 equals 2
   * @return a predicate accepting the matching feeding row, or rejecting every row when neither
   *         level applies
   */
  public Predicate<NtaComorbidityRow> step1$2ParenteralIvFeedingCondition(
      Supplier<Boolean> hasK0510A2, Supplier<Integer> k0710A2Value,
      Supplier<Integer> k0710B2Value) {
    ParenteralIvFeeding feeding = null;
    // equals() instead of auto-unboxing: a null supplier result must not throw an NPE.
    if (Boolean.TRUE.equals(hasK0510A2.get())) {
      Integer k0710A2 = k0710A2Value.get();
      if (Integer.valueOf(3).equals(k0710A2)) {
        feeding = ParenteralIvFeeding.HighIntensity;
      } else if (Integer.valueOf(2).equals(k0710A2)
          && Integer.valueOf(2).equals(k0710B2Value.get())) {
        feeding = ParenteralIvFeeding.LowIntensity;
      }
    }

    Predicate<NtaComorbidityRow> result = (row) -> false;
    if (feeding != null) {
      result = NtaLogic.FeedingFilter.apply(feeding.getItems());
    }

    return SnfContext.trace(feeding, result);
  }

  /**
   * Determine whether the resident has any additional NTA-related comorbidities.
   * <a href="doc-files/mds-3.0-rai-manual-v1.17.1_october_2019.pdf#page=686" class=
   * "req">step1.3</a>
   *
   * <p>A row matches when its MDS items contain at least one of the given assessment item names.
   * The parenteral/IV feeding rows never match here, because they are handled by step 1.2.
   *
   * @param assessmentNames item names that qualify on the assessment
   * @param row the comorbidity table row to test
   * @return {@code true} if the row is a non-feeding row that matches one of the names
   */
  public boolean step1$3AdditionalCommorbidities(Set<String> assessmentNames,
      NtaComorbidityRow row) {
    // The feeding check does not depend on the loop variable, so it is done once up front.
    // Comparing from the enum side keeps this null-safe for rows without MDS items.
    if (ParenteralIvFeeding.LowIntensity.getItems().equals(row.getMdsItems())
        || ParenteralIvFeeding.HighIntensity.getItems().equals(row.getMdsItems())) {
      return false;
    }
    for (String name : assessmentNames) {
      if (SnfUtils.containsAny(row.getMdsItems(), SnfUtils.toSet(name))) {
        SnfContext.trace(name);
        return true;
      }
    }
    return false;
  }


  /**
   * Summarize the resident’s total NTA score from previous steps
   * <a href="doc-files/mds-3.0-rai-manual-v1.17.1_october_2019.pdf#page=688" class="req">step 2</a>
   *
   * @param secondaryDxNtaCategories NTA categories derived from the secondary diagnoses
   * @param step1$3AdditionalCommorbidities predicate from step 1.3
   * @param parenteralIvFeedingFilter predicate from step 1.2
   * @return Total NTA score: the sum of the points of every comorbidity row that is selected for
   *         this grouper version
   */
  public int step2NtaScore(Set<String> secondaryDxNtaCategories,
      Predicate<NtaComorbidityRow> step1$3AdditionalCommorbidities,
      Predicate<NtaComorbidityRow> parenteralIvFeedingFilter) {

    final Predicate<NtaComorbidityRow> conditionService =
        (row) -> secondaryDxNtaCategories.contains(row.getConditionService());

    // A row scores if any of these holds:
    // - its condition/service is one of the secondary-diagnosis NTA categories;
    // - it contains any qualifying assessment item (step 1.3);
    // - it is the selected parenteral/IV feeding row (step 1.2).
    Predicate<NtaComorbidityRow> conditions = SnfUtils.or(Arrays.asList(conditionService,
        step1$3AdditionalCommorbidities, parenteralIvFeedingFilter));

    // Restrict to rows applicable to this grouper version.
    conditions =
        conditions.and((row) -> BasicRow.getVersionSelector().apply(row, super.getVersion()));

    List<Integer> scores = SnfTables.selectAll(SnfTables.ntaComorbidityTableByConditionOfService,
        conditions, (row) -> {
          Integer point = SnfUtils.nullCheck(row, 0, row.getPoint());
          return point;
        });
    Integer sum = SnfComparator.sum(scores).intValue();

    return SnfContext.trace(sum);
  }

  /**
   * <a href="doc-files/mds-3.0-rai-manual-v1.17.1_october_2019.pdf#page=689" class="req">implement
   * step 3</a>
   *
   * <p>Looks up the NTA CMG table row that matches this grouper version and the given score.
   * NOTE(review): assumes the table has a row for every reachable score; a missing row would throw
   * an NPE here - confirm.
   *
   * @param ntaScore total NTA score from step 2
   * @return NTA Case-Mix Group
   */
  public String step3NtaCaseMixGroup(int ntaScore) {
    NtaCmgRow row = SnfTables.get(SnfTables.ntaCmgTable, NtaCmgRow.verScoreFilter,
        Pair.of(super.getVersion(), ntaScore));
    String cmg = row.getCmg();

    return SnfContext.trace(cmg);
  }

  /**
   * <a href="doc-files/mds-3.0-rai-manual-v1.17.1_october_2019.pdf#page=686" class="req">Implement
   * PDPM Payment Component: NTA</a>
   *
   * <p>Runs step 2 with the inputs prepared in the constructor, then maps the score via step 3.
   *
   * @return the NTA case-mix group
   */
  @Override
  public String exec() {
    int ntaScore = this.step2NtaScore(this.secondaryDxNtaCategories,
        step1$3AdditionalCommorbidities, parenteralIvFeeding);

    String result = this.step3NtaCaseMixGroup(ntaScore);

    return SnfContext.trace(result);
  }

  /** @return NTA categories of the secondary diagnoses (excluding B20) */
  public Set<String> getSecondaryDxNtaCategories() {
    return secondaryDxNtaCategories;
  }

  /** @return qualifying assessment item names used by step 1.3 (the live, mutable set) */
  public Set<String> getAssessmentNames() {
    return assessmentNames;
  }

  /** @return supplier telling whether K0510A2 is checked */
  public Supplier<Boolean> getHasK0510A2() {
    return hasK0510A2;
  }

  /** @return supplier of the K0710A2 value */
  public Supplier<Integer> getK0710A2Value() {
    return k0710A2Value;
  }

  /** @return supplier of the K0710B2 value */
  public Supplier<Integer> getK0710B2Value() {
    return k0710B2Value;
  }

  /** @return the step 1.2 parenteral/IV feeding predicate */
  public Predicate<NtaComorbidityRow> getParenteralIvFeeding() {
    return parenteralIvFeeding;
  }

  /** @return the step 1.3 additional-comorbidity predicate */
  public Predicate<NtaComorbidityRow> getStep1$3AdditionalCommorbidities() {
    return step1$3AdditionalCommorbidities;
  }

  /** @return the claim this logic was built from */
  public ClaimInfo getClaim() {
    return claim;
  }

}
