import * as d3 from "d3"
import { round } from "lodash"

// Sleep stages in legend order; SLEEP_STAGE_COLORS pairs with it index by index.
const SLEEP_STAGES = ["REM", "Core", "Deep", "In Bed", "Awake", "Asleep"]
const SLEEP_STAGE_COLORS = ["#B7D3FF", "#5E89FF", "#0D1FF5", "#FFDB88", "#FF6D5C", "#8CE1FF"]
// Bottom-to-top order of the segments within each stacked bar.
const STACK_ORDER = ["Deep", "Core", "REM", "Asleep", "In Bed", "Awake"]
// Indexed by Date#getDay() (0 = Sunday).
const WEEKDAY_LABELS = ["Su", "Mo", "Tu", "We", "Th", "Fr", "Sa"]
const MS_PER_HOUR = 1000 * 60 * 60
const TOOLTIP_CLASS = "activity-stack-tooltip"

/**
 * Resolves a clock time such as "22:30" (or "22:30:15") to a local Date on the calendar day of `baseDate`.
 *
 * @param {Date} baseDate - day the time belongs to; its time-of-day is overwritten.
 * @param {string} timeStr - "HH:MM" or "HH:MM:SS"; missing minute/second parts count as 0.
 * @returns {Date|null} the resolved time, or null when `timeStr` is not a string or does not parse.
 */
function clockTimeOn(baseDate, timeStr) {
  if (typeof timeStr !== "string") return null
  const [hours, minutes = 0, seconds = 0] = timeStr.split(":").map(Number)
  const result = new Date(baseDate.getTime())
  result.setHours(hours, minutes, seconds, 0)
  return Number.isNaN(result.getTime()) ? null : result
}

/**
 * Stacked bar chart, drawn with D3, of the hours spent in each sleep stage per day.
 *
 * The bar for date D covers the window from 18:00 on D-1 to 18:00 on D, so sleep that starts in the
 * evening is counted towards the following day.
 */
class ActivityAggregateGraphD3 {
  /**
   * Stores the configuration and renders the chart immediately.
   *
   * @param {*} element - container handed to `d3.select`; the SVG and the tooltip div are appended to it.
   * @param {number} width - total SVG width in px, margins included.
   * @param {number} height - total SVG height in px, margins included.
   * @param {Object<string, Array<Object>>} data - activities keyed by "YYYY-MM-DD". Each activity is read
   *   for `type`, `level`, `start_time` and `end_time` ("HH:MM").
   */
  constructor(element, width, height, data) {
    this.element = element
    this.width = width
    this.height = height
    this.data = data
    this.margin = { top: 20, right: 10, bottom: 70, left: 40 }
    this.svg = null

    this.createChart()
  }

  /**
   * (Re)draws the chart into `this.element`, replacing any SVG and tooltip left by an earlier render.
   *
   * Logs a warning and draws nothing when `this.data` is empty or none of its keys parses as a date.
   * On success `this.svg` holds the translated inner `<g>` the chart is drawn in.
   */
  createChart() {
    const container = d3.select(this.element)
    container.select("svg").remove()
    // The tooltip div lives outside the SVG, so it must be cleared separately or re-renders pile them up.
    container.selectAll(`.${TOOLTIP_CLASS}`).remove()
    this.svg = null

    if (!this.data || Object.keys(this.data).length === 0) {
      console.warn("Empty data, skipping chart creation.")
      return
    }

    const parseDate = d3.timeParse("%Y-%m-%d")
    const formatDate = d3.timeFormat("%Y-%m-%d")

    // timeParse yields null for keys that are not YYYY-MM-DD; drop them instead of crashing on them later.
    const dates = Object.keys(this.data)
      .map((key) => parseDate(key))
      .filter((date) => date !== null)
      .sort((a, b) => a - b)

    if (!dates.length) {
      console.warn("No valid dates found in data.")
      return
    }

    const innerWidth = this.width - this.margin.left - this.margin.right
    const innerHeight = this.height - this.margin.top - this.margin.bottom

    const svg = container.append("svg").attr("width", this.width).attr("height", this.height)
    const chartGroup = svg.append("g").attr("transform", `translate(${this.margin.left},${this.margin.top})`)
    this.svg = chartGroup

    const aggregatedData = ActivityAggregateGraphD3.aggregateSleepWindows(this.data, dates, formatDate)

    // One row per date with a column per stage, in the shape d3.stack expects.
    const stackData = aggregatedData.map(({ date, aggregated }) => {
      const row = { date }
      STACK_ORDER.forEach((stage) => {
        row[stage] = aggregated[stage]
      })
      return row
    })
    const series = d3.stack().keys(STACK_ORDER)(stackData)

    const xScale = d3.scaleBand().domain(dates).range([0, innerWidth]).padding(0.1)

    const totalHours = (entry) => STACK_ORDER.reduce((sum, stage) => sum + (entry.aggregated[stage] || 0), 0)
    // The y axis always spans at least 10 hours, even when every night is shorter.
    const yAxisMax = Math.max(d3.max(aggregatedData, totalHours) || 0, 10)
    const yScale = d3.scaleLinear().domain([0, yAxisMax]).range([innerHeight, 0])

    const colorScaleSleep = d3.scaleOrdinal().domain(SLEEP_STAGES).range(SLEEP_STAGE_COLORS)

    this.drawXAxis(chartGroup, xScale, innerHeight)
    this.drawGridLines(chartGroup, xScale, innerHeight)
    this.drawYAxis(chartGroup, yScale)
    this.drawBars(chartGroup, series, xScale, yScale, colorScaleSleep)

    this.addSleepLegend(svg, colorScaleSleep)
    this.addTooltip(chartGroup, xScale, yScale)
  }

  /**
   * Sums the hours spent in each sleep stage inside every date's cutoff-to-cutoff window.
   *
   * The window for date D runs from `cutoffHour`:00 on D-1 to `cutoffHour`:00 on D. Activities are read from
   * the entries for D-1 and D; their clock times are relative to the day they are filed under, and an end
   * time earlier than the start time is taken to fall on the next day (a session that ran past midnight).
   * Only activities with `type === "sleep"` and a known stage count, and only the part inside the window.
   * Activities with a missing or unparseable start/end time are skipped.
   *
   * @param {Object<string, Array<Object>>} data - activities keyed by the strings `formatDate` produces.
   * @param {Date[]} dates - dates to aggregate; the output keeps this order.
   * @param {(date: Date) => string} formatDate - maps a Date to its key in `data`.
   * @param {number} [cutoffHour=18] - hour of day at which each window starts and ends.
   * @returns {Array<{date: Date, aggregated: Object<string, number>}>} hours per stage for each date.
   */
  static aggregateSleepWindows(data, dates, formatDate, cutoffHour = 18) {
    return dates.map((day) => {
      const previousDay = new Date(day.getTime())
      previousDay.setDate(previousDay.getDate() - 1)

      const windowEnd = new Date(day.getTime())
      windowEnd.setHours(cutoffHour, 0, 0, 0)
      const windowStart = new Date(windowEnd.getTime())
      windowStart.setDate(windowStart.getDate() - 1)

      // Pair each activity with the day it is filed under, since its clock times are relative to that day.
      const candidates = [
        ...(data[formatDate(previousDay)] || []).map((activity) => ({ activity, baseDay: previousDay })),
        ...(data[formatDate(day)] || []).map((activity) => ({ activity, baseDay: day })),
      ]

      const aggregated = {}
      SLEEP_STAGES.forEach((stage) => {
        aggregated[stage] = 0
      })

      candidates.forEach(({ activity, baseDay }) => {
        if (activity.type !== "sleep" || !SLEEP_STAGES.includes(activity.level)) return

        const start = clockTimeOn(baseDay, activity.start_time)
        const end = clockTimeOn(baseDay, activity.end_time)
        if (!start || !end) return
        // An end earlier than the start (e.g. 22:30 -> 06:45) means the session crossed midnight.
        if (end < start) end.setDate(end.getDate() + 1)

        const overlapStart = Math.max(start.getTime(), windowStart.getTime())
        const overlapEnd = Math.min(end.getTime(), windowEnd.getTime())
        if (overlapEnd > overlapStart) {
          aggregated[activity.level] += (overlapEnd - overlapStart) / MS_PER_HOUR
        }
      })

      return { date: day, aggregated }
    })
  }

  /**
   * Draws the bottom axis: weekday labels, a black rounded badge behind Saturday/Sunday labels,
   * and an MM/DD label below each weekday.
   */
  drawXAxis(chartGroup, xScale, innerHeight) {
    const xAxis = d3.axisBottom(xScale).tickFormat((date) => WEEKDAY_LABELS[date.getDay()])
    const xAxisG = chartGroup.append("g").attr("transform", `translate(0, ${innerHeight})`).call(xAxis)
    xAxisG.selectAll("path, line").style("stroke", "#888888")
    xAxisG.selectAll("text").attr("dy", "1em").style("color", "#888888").style("text-anchor", "middle")

    const badgeWidth = 18
    const badgeHeight = 16
    const badgeRadius = 6
    const weekendTicks = xAxisG.selectAll(".tick").filter((date) => date.getDay() === 0 || date.getDay() === 6)
    weekendTicks
      .insert("rect", "text")
      .attr("x", -badgeWidth / 2)
      .attr("y", badgeHeight / 2)
      .attr("width", badgeWidth)
      .attr("height", badgeHeight)
      .attr("rx", badgeRadius)
      .attr("ry", badgeRadius)
      .style("fill", "black")
    weekendTicks.select("text").style("fill", "white")

    const formatMonthDay = d3.timeFormat("%m/%d")
    xAxisG
      .selectAll(".tick")
      .append("text")
      .attr("dy", "35px")
      .attr("font-size", "10px")
      .attr("font-family", "sans-serif")
      .attr("fill", "#888888")
      .text((date) => formatMonthDay(date))
  }

  /** Draws a dashed vertical line through the centre of every band of `xScale`. */
  drawGridLines(chartGroup, xScale, innerHeight) {
    const halfBand = xScale.bandwidth() / 2
    xScale.domain().forEach((date) => {
      const x = xScale(date) + halfBand
      chartGroup
        .append("line")
        .attr("x1", x)
        .attr("x2", x)
        .attr("y1", 0)
        .attr("y2", innerHeight)
        .style("stroke", "#E7E7E7")
        .style("stroke-width", 1)
        .style("stroke-dasharray", "3,3")
    })
  }

  /** Draws the left axis with "<n> hrs" tick labels. */
  drawYAxis(chartGroup, yScale) {
    const yAxis = d3.axisLeft(yScale).tickFormat((hours) => `${hours} hrs`)
    const yAxisG = chartGroup.append("g").call(yAxis)
    yAxisG.selectAll("path, line").style("stroke", "#888888")
    yAxisG.selectAll("text").style("fill", "#888888")
  }

  /** Draws one `g.layer` per stacked series, with one rect per date in each. */
  drawBars(chartGroup, series, xScale, yScale, colorScale) {
    chartGroup
      .selectAll("g.layer")
      .data(series)
      .enter()
      .append("g")
      .attr("class", "layer")
      .attr("fill", (layer) => colorScale(layer.key))
      .selectAll("rect")
      .data((layer) => layer)
      .enter()
      .append("rect")
      .attr("x", (segment) => xScale(segment.data.date))
      .attr("y", (segment) => yScale(segment[1]))
      .attr("width", xScale.bandwidth())
      .attr("height", (segment) => yScale(segment[0]) - yScale(segment[1]))
  }

  /**
   * Appends a hidden tooltip div to the container. Hovering a bar segment shows its stage and duration.
   *
   * @param {*} svg - d3 selection containing the `g.layer` bar groups.
   * @param {*} xScale - unused; kept for call compatibility.
   * @param {*} yScale - unused; kept for call compatibility.
   */
  addTooltip(svg, xScale, yScale) {
    const tooltip = d3
      .select(this.element)
      .append("div")
      .attr("class", TOOLTIP_CLASS)
      .style("position", "absolute")
      .style("padding", "4px 12px")
      .style("background", "#000")
      .style("border-radius", "12px")
      .style("color", "#fff")
      .style("font-size", "14px")
      .style("pointer-events", "none")
      .style("opacity", 0)

    // Only bar segments get tooltips: the weekend badges on the x axis are rects too, but carry a Date datum.
    svg
      .selectAll("g.layer rect")
      .on("mouseover", (event, segment) => {
        d3.selectAll(`.${TOOLTIP_CLASS}`).style("opacity", 0)
        const stage = d3.select(event.currentTarget.parentNode).datum().key
        const duration = (segment[1] - segment[0]).toFixed(2) + " hrs"
        tooltip
          .text(`${stage} ${duration}`)
          .style("left", `${event.pageX + 10}px`)
          .style("top", `${event.pageY - 20}px`)
          .transition()
          .style("opacity", 1)
      })
      .on("mouseout", () => tooltip.transition().style("opacity", 0))
  }

  /** Draws the "Sleep Stages" legend, one colour swatch per stage, below the chart. */
  addSleepLegend(svg, colorScaleSleep) {
    const legend = svg.append("g").attr("transform", `translate(${this.margin.left}, ${this.height - 20})`)

    legend
      .append("text")
      .text("Sleep Stages")
      .style("font-size", "14px")
      .style("font-family", "sans-serif")
      .style("font-weight", "bold")
      .style("fill", "#888888")
      .attr("x", 0)
      .attr("y", 14)

    colorScaleSleep.domain().forEach((level, i) => {
      const item = legend.append("g").attr("transform", `translate(${i * 70 + 95}, 0)`)

      item
        .append("rect")
        .attr("width", 18)
        .attr("height", 18)
        .attr("rx", 5)
        .attr("ry", 5)
        .attr("fill", colorScaleSleep(level))

      item
        .append("text")
        .attr("x", 22)
        .attr("y", 14)
        .attr("fill", "#888888")
        .text(level)
        .style("font-size", "12px")
        .style("font-family", "sans-serif")
    })
  }
}

export default ActivityAggregateGraphD3
