import * as d3 from "d3"

/**
 * Stacked bar chart of per-day sleep-stage totals, rendered with D3 into `element`.
 *
 * Each bar sums the sleep activities that fall inside the window from 18:00 on the
 * previous day to 18:00 on the bar's date. Clicking a legend entry hides or shows
 * that stage and restacks the bars.
 */
class ActivityAggregateGraphD3 {
  /**
   * @param {*} element - anything `d3.select` accepts; the SVG and tooltip div are appended to it
   * @param {number} width - outer SVG width
   * @param {number} height - outer SVG height
   * @param {Object<string, Array<Object>>} data - activities keyed by "YYYY-MM-DD"; entries with
   *   `type === "sleep"` are charted using their `level`, `start_time` and `end_time` ("HH:MM")
   */
  constructor(element, width, height, data) {
    this.element = element
    this.width = width
    this.height = height
    this.data = data
    this.margin = { top: 20, right: 10, bottom: 70, left: 40 }
    this.svg = null
    this.allStages = ["Deep", "Core", "REM", "Asleep", "In Bed", "Awake"]
    // Stages currently toggled off via the legend; "In Bed" starts hidden.
    this.hiddenStages = ["In Bed"]

    this.createChart()
  }

  /**
   * (Re)renders the whole chart: clears any previous render, aggregates the data,
   * then draws axes, bars, legend and tooltip.
   */
  createChart() {
    const container = d3.select(this.element)
    // The tooltip div lives outside the SVG, so it must be removed separately;
    // otherwise every re-render leaves another stale tooltip behind.
    container.select("svg").remove()
    container.selectAll(".activity-aggregate-tooltip").remove()

    if (!this.data || Object.keys(this.data).length === 0) {
      console.warn("Empty data, skipping chart creation.")
      return
    }

    const parseDate = d3.timeParse("%Y-%m-%d")
    // timeParse returns null for keys that are not YYYY-MM-DD; drop them rather than
    // letting a null date crash formatting further down.
    const dates = Object.keys(this.data)
      .map((key) => parseDate(key))
      .filter((date) => date !== null)
      .sort((a, b) => a - b)

    if (!dates.length) {
      console.warn("No valid dates found in data.")
      return
    }

    // With more than 30 bars the x axis only labels Mondays (see buildAxes).
    this.denseMode = dates.length > 30
    this.innerWidth = this.width - this.margin.left - this.margin.right
    this.innerHeight = this.height - this.margin.top - this.margin.bottom

    const svg = container
      .append("svg")
      .attr("width", this.width)
      .attr("height", this.height)
      .style("user-select", "none")

    const chartGroup = svg.append("g").attr("transform", `translate(${this.margin.left},${this.margin.top})`)
    // Note: `this.svg` holds the margin-translated group, not the <svg> element itself.
    this.svg = chartGroup

    this.aggregatedData = this.aggregateSleepData(dates)

    this.xScale = d3.scaleBand().domain(dates).range([0, this.innerWidth]).padding(0.1)

    this.colorScaleSleep = d3
      .scaleOrdinal()
      .domain(this.allStages)
      .range(["#0D1FF5", "#5E89FF", "#B7D3FF", "#8CE1FF", "#FFDB88", "#FF6D5C"])

    this.buildAxes(chartGroup)

    // Bars are drawn after the axes so they sit above the dashed grid lines.
    this.layersGroup = chartGroup.append("g").attr("class", "layers")
    this.updateStackedBars()

    this.addSleepLegend(svg, this.colorScaleSleep)
    this.addTooltip(chartGroup, this.xScale)
  }

  /**
   * Builds one `{ date, aggregated }` record per date, where `aggregated` maps every
   * stage in `this.allStages` to the hours recorded in that date's window.
   *
   * @param {Date[]} dates - dates parsed from the data keys
   * @returns {Array<{date: Date, aggregated: Object<string, number>}>}
   */
  aggregateSleepData(dates) {
    const formatDate = d3.timeFormat("%Y-%m-%d")
    // A window starts on the previous day, so activities stored under both keys are read.
    const activitiesOn = (day) => {
      const entries = this.data[formatDate(day)]
      return Array.isArray(entries) ? entries : []
    }

    return dates.map((date) => ({
      date,
      aggregated: ActivityAggregateGraphD3.aggregateWindow(
        date,
        activitiesOn(ActivityAggregateGraphD3.shiftDays(date, -1)),
        activitiesOn(date),
        this.allStages,
      ),
    }))
  }

  /**
   * Sums sleep hours per stage inside the window [previous day boundaryHour:00,
   * `date` boundaryHour:00), clipping activities that straddle either edge.
   *
   * Only activities with `type === "sleep"` and a `level` listed in `stages` count.
   * An `end_time` earlier than its `start_time` is treated as ending on the next day.
   * Activities that are null or whose times cannot be parsed are skipped.
   *
   * @param {Date} date - day the window ends on (local time)
   * @param {Array<Object>} prevDayActivities - activities stored under the previous day's key
   * @param {Array<Object>} currentDayActivities - activities stored under `date`'s key
   * @param {string[]} stages - stage names to report; every one appears in the result
   * @param {number} [boundaryHour=18] - local hour at which one day's window ends and the next begins
   * @returns {Object<string, number>} hours per stage (0 when nothing was recorded)
   */
  static aggregateWindow(date, prevDayActivities, currentDayActivities, stages, boundaryHour = 18) {
    const msPerHour = 1000 * 60 * 60

    const windowStart = new Date(date.getTime())
    windowStart.setHours(boundaryHour, 0, 0, 0)
    windowStart.setDate(windowStart.getDate() - 1)
    const windowEnd = new Date(date.getTime())
    windowEnd.setHours(boundaryHour, 0, 0, 0)
    const windowStartMs = windowStart.getTime()
    const windowEndMs = windowEnd.getTime()

    const totals = {}
    stages.forEach((stage) => {
      totals[stage] = 0
    })
    const knownStages = new Set(stages)

    const accumulate = (activities, baseDay) => {
      activities.forEach((activity) => {
        if (!activity || activity.type !== "sleep" || !knownStages.has(activity.level)) {
          return
        }
        const start = ActivityAggregateGraphD3.timeOnDay(baseDay, activity.start_time)
        const end = ActivityAggregateGraphD3.timeOnDay(baseDay, activity.end_time)
        if (!start || !end) {
          return
        }
        // An end time before the start time means the activity ran past midnight.
        if (end.getTime() < start.getTime()) {
          end.setDate(end.getDate() + 1)
        }

        const clippedStart = Math.max(start.getTime(), windowStartMs)
        const clippedEnd = Math.min(end.getTime(), windowEndMs)
        if (clippedEnd > clippedStart) {
          totals[activity.level] += (clippedEnd - clippedStart) / msPerHour
        }
      })
    }

    accumulate(prevDayActivities, ActivityAggregateGraphD3.shiftDays(date, -1))
    accumulate(currentDayActivities, date)
    return totals
  }

  /**
   * Returns a copy of `day` with its local time set from an "HH:MM" string,
   * or null when the string is missing or does not yield a valid time.
   *
   * @param {Date} day
   * @param {string} timeStr
   * @returns {Date|null}
   */
  static timeOnDay(day, timeStr) {
    if (typeof timeStr !== "string") {
      return null
    }
    const [hours, minutes] = timeStr.split(":").map(Number)
    const result = new Date(day.getTime())
    result.setHours(hours, minutes, 0, 0)
    return Number.isNaN(result.getTime()) ? null : result
  }

  /**
   * Returns a new Date shifted by whole calendar days (local time); `date` is not mutated.
   *
   * @param {Date} date
   * @param {number} days - may be negative
   * @returns {Date}
   */
  static shiftDays(date, days) {
    const shifted = new Date(date.getTime())
    shifted.setDate(shifted.getDate() + days)
    return shifted
  }

  /**
   * CSS class used for a stage's legend entry (whitespace replaced by "-").
   *
   * @param {string} stage
   * @returns {string}
   */
  static legendClassName(stage) {
    return `legend-item-${stage.replace(/\s/g, "-")}`
  }

  /**
   * Draws the x axis (weekday and MM/DD labels, weekend badges, dashed grid lines),
   * then creates the y scale and an empty y-axis group that `updateStackedBars` fills in.
   *
   * @param {Object} chartGroup - d3 selection of the margin-translated <g>
   */
  buildAxes(chartGroup) {
    const dayAbbreviations = ["Su", "Mo", "Tu", "We", "Th", "Fr", "Sa"]
    const formatMonthDay = d3.timeFormat("%m/%d")
    const isWeekend = (date) => date.getDay() === 0 || date.getDay() === 6
    // In dense mode only Mondays get labels.
    const isLabelled = (date) => !this.denseMode || date.getDay() === 1

    const xAxis = d3
      .axisBottom(this.xScale)
      .tickFormat((d) => (isLabelled(d) ? dayAbbreviations[d.getDay()] : ""))

    const xAxisG = chartGroup.append("g").attr("transform", `translate(0, ${this.innerHeight})`).call(xAxis)

    xAxisG.selectAll("path, line").style("stroke", "#888888")
    xAxisG.selectAll("text").attr("dy", "1em").style("color", "#888888").style("text-anchor", "middle")

    const ticks = xAxisG.selectAll(".tick")

    // Black rounded badge behind weekend weekday labels, only when every tick is labelled.
    if (!this.denseMode) {
      ticks.filter(isWeekend).each(function () {
        const tick = d3.select(this)
        const rectHeight = 16
        const rectWidth = 18
        const roundRadius = 6

        tick
          .insert("rect", "text")
          .attr("x", -rectWidth / 2)
          .attr("y", rectHeight / 2)
          .attr("width", rectWidth)
          .attr("height", rectHeight)
          .attr("rx", roundRadius)
          .attr("ry", roundRadius)
          .style("fill", "black")

        tick.select("text").style("fill", "white")
      })
    }

    // Secondary MM/DD label under each weekday label.
    ticks
      .append("text")
      .attr("dy", "35px")
      .attr("font-size", "10px")
      .attr("font-family", "sans-serif")
      .attr("fill", "#888888")
      .text((d) => (isLabelled(d) ? formatMonthDay(d) : ""))

    // Dashed vertical grid line through the centre of each band; darker on weekends.
    ticks.each((d) => {
      const tickX = this.xScale(d) + this.xScale.bandwidth() / 2
      chartGroup
        .append("line")
        .attr("x1", tickX)
        .attr("x2", tickX)
        .attr("y1", 0)
        .attr("y2", this.innerHeight)
        .style("stroke", isWeekend(d) ? "#bbbbbb" : "#E7E7E7")
        .style("stroke-width", 1)
        .style("stroke-dasharray", "3,3")
    })

    // The y domain depends on which stages are visible, so it is set in updateStackedBars.
    this.yScale = d3.scaleLinear().range([this.innerHeight, 0])
    this.yAxisG = chartGroup.append("g")
  }

  /**
   * Restacks the bars for the currently visible stages, rescales the y axis
   * (minimum 10 hrs), and animates the bars into place.
   */
  updateStackedBars() {
    const visibleStages = this.allStages.filter((stage) => !this.hiddenStages.includes(stage))

    const stackInput = this.aggregatedData.map(({ date, aggregated }) => {
      const row = { date }
      visibleStages.forEach((stage) => {
        row[stage] = aggregated[stage]
      })
      return row
    })
    const series = d3.stack().keys(visibleStages)(stackInput)

    const maxSleepHours = d3.max(this.aggregatedData, (d) =>
      visibleStages.reduce((sum, stage) => sum + (d.aggregated[stage] || 0), 0),
    )
    this.yScale.domain([0, Math.max(maxSleepHours || 0, 10)])

    this.yAxisG.call(d3.axisLeft(this.yScale).tickFormat((d) => `${d} hrs`))
    this.yAxisG.selectAll("path, line").style("stroke", "#888888")
    this.yAxisG.selectAll("text").style("fill", "#888888")

    // One <g class="layer"> per visible stage, keyed by stage name so hidden stages exit.
    const layers = this.layersGroup.selectAll("g.layer").data(series, (d) => d.key)
    layers.exit().remove()

    const layersMerged = layers
      .enter()
      .append("g")
      .attr("class", "layer")
      .attr("fill", (d) => this.colorScaleSleep(d.key))
      .merge(layers)

    // One rect per date within each layer, keyed by date so existing bars animate.
    const rects = layersMerged.selectAll("rect").data(
      (d) => d,
      (d) => d.data.date,
    )
    rects.exit().remove()

    rects
      .enter()
      .append("rect")
      .attr("x", (d) => this.xScale(d.data.date))
      .attr("y", this.yScale(0))
      .attr("height", 0)
      .attr("width", this.xScale.bandwidth())
      .merge(rects)
      .transition()
      .duration(500)
      .ease(d3.easeCubicOut)
      .attr("y", (d) => this.yScale(d[1]))
      .attr("height", (d) => this.yScale(d[0]) - this.yScale(d[1]))

    // Newly entered rects need hover handlers too.
    this.rebindTooltipEvents()
  }

  /**
   * Hides a visible stage or shows a hidden one, then redraws the bars and greys out
   * (or restores) the stage's legend swatch.
   *
   * @param {string} stage
   */
  toggleStage(stage) {
    const idx = this.hiddenStages.indexOf(stage)
    if (idx > -1) {
      this.hiddenStages.splice(idx, 1) // show
    } else {
      this.hiddenStages.push(stage) // hide
    }

    this.updateStackedBars()

    const isHidden = this.hiddenStages.includes(stage)
    d3.select(this.element)
      .selectAll(`.${ActivityAggregateGraphD3.legendClassName(stage)}`)
      .select("rect")
      .attr("fill", isHidden ? "#e7e7e7" : this.colorScaleSleep(stage))
  }

  /**
   * Appends the (initially transparent) tooltip div to `this.element` and wires the
   * bar hover handlers. Both parameters are unused; they are kept for compatibility.
   */
  addTooltip(svg, xScale) {
    this.tooltip = d3
      .select(this.element)
      .append("div")
      .attr("class", "activity-aggregate-tooltip")
      .style("position", "absolute")
      .style("padding", "4px 12px")
      .style("background", "#000")
      .style("border-radius", "12px")
      .style("color", "#fff")
      .style("font-size", "14px")
      .style("pointer-events", "none")
      .style("opacity", 0)
    this.rebindTooltipEvents()
  }

  /**
   * Attaches hover handlers to every bar: mouseover shows "<stage> <hours> hrs" near
   * the cursor, mouseout fades the tooltip. d3 replaces a listener of the same type,
   * so calling this repeatedly does not stack handlers.
   */
  rebindTooltipEvents() {
    this.layersGroup
      .selectAll("rect")
      .on("mouseover", (event, d) => {
        // Hide any other chart's tooltip that is still visible.
        d3.selectAll(".activity-aggregate-tooltip").style("opacity", 0)
        // The stage is the key bound to the rect's parent layer group.
        const stage = d3.select(event.currentTarget.parentNode).datum().key
        const duration = `${(d[1] - d[0]).toFixed(2)} hrs`

        // NOTE(review): page coordinates assume this.element is not itself a positioned
        // ancestor offset from the page origin — confirm against the host layout.
        this.tooltip
          .text(`${stage} ${duration}`)
          .style("left", `${event.pageX + 10}px`)
          .style("top", `${event.pageY - 20}px`)
          .transition()
          .style("opacity", 1)
      })
      .on("mouseout", () => {
        this.tooltip.transition().style("opacity", 0)
      })
  }

  /**
   * Draws the clickable legend along the bottom of the SVG: a title, one swatch and
   * label per stage (grey when hidden), and a right-aligned usage tip.
   *
   * @param {Object} svg - d3 selection of the <svg> element
   * @param {Function} colorScaleSleep - ordinal scale whose domain lists the stages
   */
  addSleepLegend(svg, colorScaleSleep) {
    const legendGroup = svg.append("g").attr("transform", `translate(${this.margin.left}, ${this.height - 20})`)

    legendGroup
      .append("text")
      .text("Sleep Stages")
      .style("font-size", "14px")
      .style("font-family", "sans-serif")
      .style("font-weight", "bold")
      .style("fill", "#888888")
      .style("user-select", "none")
      .attr("x", 0)
      .attr("y", 14)

    colorScaleSleep.domain().forEach((level, i) => {
      const item = legendGroup
        .append("g")
        .attr("transform", `translate(${i * 70 + 95}, 0)`)
        .attr("class", ActivityAggregateGraphD3.legendClassName(level))
        .style("cursor", "pointer")
        .on("click", () => this.toggleStage(level))

      item
        .append("rect")
        .attr("width", 18)
        .attr("height", 18)
        .attr("rx", 5)
        .attr("ry", 5)
        .attr("fill", this.hiddenStages.includes(level) ? "#e7e7e7" : colorScaleSleep(level))

      item
        .append("text")
        .attr("x", 22)
        .attr("y", 14)
        .attr("fill", "#888888")
        .text(level)
        .style("user-select", "none")
        .style("font-size", "12px")
        .style("font-family", "sans-serif")
    })

    legendGroup
      .append("text")
      .text("Tip: Use legend icons to toggle sleep stages")
      .style("text-anchor", "end")
      .style("font-size", "14px")
      .style("font-family", "sans-serif")
      .style("fill", "#888888")
      .attr("x", this.width - this.margin.right - this.margin.left)
      .attr("y", 14)
  }
}

export default ActivityAggregateGraphD3
