import * as d3 from "d3"

/**
 * Renders a "sleep fragmentation" chart (awake events per hour slept, one
 * point per day) into a container element using D3.
 *
 * `data` is read as an object keyed by "YYYY-MM-DD" strings whose values are
 * arrays of activity records. Only records with `type === "sleep"` are used,
 * and each is read for `level`, `start_time` and `end_time` ("HH:MM").
 */
class SleepFragmentationGraphD3 {
  /**
   * Stores the configuration and renders the chart immediately.
   *
   * @param {*} element - Anything accepted by `d3.select` (node or selector).
   * @param {number} width - viewBox width of the generated svg.
   * @param {number} height - viewBox height of the generated svg.
   * @param {Object} data - Activity records keyed by "YYYY-MM-DD".
   * @param {string} [color="#FF6D5C"] - Colour of the line, dots, labels and legend swatch.
   */
  constructor(element, width, height, data, color = "#FF6D5C") {
    this.element = element
    this.width = width
    this.height = height
    this.data = data
    this.color = color

    this.margin = { top: 15, right: 10, bottom: 70, left: 40 }

    this.svg = null

    this.createChart()
  }

  /**
   * Returns a copy of `date` moved back by one calendar day (local time).
   *
   * @param {Date} date
   * @returns {Date}
   */
  static previousDay(date) {
    const prev = new Date(date.getTime())
    prev.setDate(prev.getDate() - 1)
    return prev
  }

  /**
   * Builds a Date on the same calendar day as `baseDate` at the local clock
   * time given by `timeStr` ("HH:MM"; anything after the minutes is ignored).
   *
   * @param {Date} baseDate
   * @param {string} timeStr
   * @returns {Date}
   */
  static parseClockTime(baseDate, timeStr) {
    const [hours, minutes] = timeStr.split(":").map(Number)
    const result = new Date(baseDate.getTime())
    result.setHours(hours, minutes, 0, 0)
    return result
  }

  /**
   * Turns a colour string into a token that is safe inside an element id and
   * an unquoted `url(#...)` reference (e.g. "rgb(255, 0, 0)" contains spaces
   * and parentheses that would break both).
   *
   * @param {string} color
   * @returns {string}
   */
  static gradientIdForColor(color) {
    return `line-gradient-fragmentation-${String(color).replace(/[^A-Za-z0-9_-]/g, "-")}`
  }

  /**
   * Aggregates the sleep records that fall inside the window running from
   * 18:00 on the day before `date` to 18:00 on `date`.
   *
   * Each record is clipped to the window. "Awake" segments are counted,
   * "Deep"/"Core"/"REM"/"Asleep" segments add their clipped duration (hours)
   * to the total sleep. Records whose `type` is not "sleep" are ignored.
   *
   * @param {Date} date - Local midnight of the day the window ends on.
   * @param {Array<Object>} [currentDayData=[]] - Records stored under `date`.
   * @param {Array<Object>} [prevDayData=[]] - Records stored under the previous day.
   * @returns {{date: Date, awakeCount: number, totalSleep: number, fragmentation: (number|null)}}
   *   `fragmentation` is `awakeCount / totalSleep`, or `null` when no sleep was recorded.
   */
  static aggregateNight(date, currentDayData = [], prevDayData = []) {
    const HOUR_MS = 60 * 60 * 1000
    const sleepLevels = ["Deep", "Core", "REM", "Asleep"]

    const prevDay = SleepFragmentationGraphD3.previousDay(date)

    const windowEnd = new Date(date.getTime())
    windowEnd.setHours(18, 0, 0, 0)
    const windowStart = new Date(windowEnd.getTime())
    windowStart.setDate(windowStart.getDate() - 1)
    const windowStartMs = windowStart.getTime()
    const windowEndMs = windowEnd.getTime()

    // Tag every record with the calendar day it was stored under while the
    // arrays are still separate, instead of searching for it afterwards.
    const entries = [
      ...prevDayData.map((activity) => ({ activity, baseDate: prevDay })),
      ...currentDayData.map((activity) => ({ activity, baseDate: date })),
    ].filter(({ activity }) => activity.type === "sleep")

    let awakeCount = 0
    let totalSleep = 0

    for (const { activity, baseDate } of entries) {
      const startDt = SleepFragmentationGraphD3.parseClockTime(baseDate, activity.start_time)
      const endDt = SleepFragmentationGraphD3.parseClockTime(baseDate, activity.end_time)

      // An end time earlier than the start time means the segment crosses
      // midnight (e.g. 23:00 -> 01:00, or ending at 00:00); without this it
      // would be dropped as empty.
      // NOTE(review): assumes records are keyed by their start date - confirm.
      if (endDt.getTime() < startDt.getTime()) {
        endDt.setDate(endDt.getDate() + 1)
      }

      const clippedStart = Math.max(startDt.getTime(), windowStartMs)
      const clippedEnd = Math.min(endDt.getTime(), windowEndMs)

      // Written as a negated ">" so unparsable (NaN) times are skipped too.
      if (!(clippedEnd > clippedStart)) {
        continue
      }

      if (activity.level === "Awake") {
        awakeCount += 1
      } else if (sleepLevels.includes(activity.level)) {
        totalSleep += (clippedEnd - clippedStart) / HOUR_MS
      }
    }

    // Mark fragmentation null if no sleep was recorded
    const fragmentation = totalSleep > 0 ? awakeCount / totalSleep : null

    return {
      date,
      awakeCount,
      totalSleep,
      fragmentation,
    }
  }

  /**
   * (Re)builds the whole chart inside `this.element`: grid, axes, gradient
   * area, line, dashed bridges over missing days, data points, optional value
   * labels, tooltip and hover overlay. Any svg/tooltip left by a previous
   * call is removed first. Nothing is drawn when there is no usable data.
   */
  createChart() {
    const container = d3.select(this.element)

    // Remove anything a previous render left behind. The tooltip must be
    // removed as well, otherwise every re-render leaks one more <div>.
    container.select("svg").remove()
    container.selectAll(".sleep-fragmentation-tooltip").remove()
    d3.selectAll(".sleep-fragmentation-tooltip").style("opacity", 0)

    if (!this.data || Object.keys(this.data).length === 0) {
      console.warn("Empty data, skipping chart creation.")
      return
    }

    // Prepare data for fragmentation graph
    const parseDate = d3.timeParse("%Y-%m-%d")
    const formatDate = d3.timeFormat("%Y-%m-%d")
    const formatShortDate = d3.timeFormat("%m/%d")

    // timeParse yields null for keys that are not dates; drop those so the
    // rest of the method only ever sees real Date objects.
    const dates = Object.keys(this.data)
      .map((key) => parseDate(key))
      .filter((date) => date !== null)
      .sort((a, b) => a - b)

    if (!dates.length) {
      console.warn("No valid dates found in data.")
      return
    }

    // Use dense mode if the number of dates exceeds 30.
    const denseMode = dates.length > 30

    // Create the main svg
    const svg = container
      .append("svg")
      .attr("width", "100%")
      .attr("height", "100%")
      .attr("viewBox", `0 0 ${this.width} ${this.height}`)
      .style("user-select", "none")

    // Main chart group
    const chartGroup = svg.append("g").attr("transform", `translate(${this.margin.left},${this.margin.top})`)

    this.svg = chartGroup

    const innerWidth = this.width - this.margin.left - this.margin.right
    const innerHeight = this.height - this.margin.top - this.margin.bottom

    // Aggregate from 6pm -> 6pm
    const aggregatedData = dates.map((date) => {
      const prevDay = SleepFragmentationGraphD3.previousDay(date)
      return SleepFragmentationGraphD3.aggregateNight(
        date,
        this.data[formatDate(date)] || [],
        this.data[formatDate(prevDay)] || [],
      )
    })

    // Scale and axes
    const xScale = d3
      .scaleBand()
      .domain(aggregatedData.map((d) => d.date))
      .range([0, innerWidth])
      .padding(0.1)

    // Horizontal centre of a day's band
    const centerX = (d) => xScale(d.date) + xScale.bandwidth() / 2
    const isWeekend = (date) => date.getDay() === 0 || date.getDay() === 6
    const hasValue = (d) => d.fragmentation !== null

    // Add legend
    this.addSleepLegend(svg)

    // Get max fragmentation among valid data
    const validData = aggregatedData.filter(hasValue)
    const maxFragmentation = validData.length ? d3.max(validData, (d) => d.fragmentation) : 0
    const yDomainMax = Math.max(1, maxFragmentation)
    const yScale = d3
      .scaleLinear()
      .domain([0, Math.max(4, yDomainMax)])
      .range([innerHeight, 0])

    // Create axes with conditional tick formatting:
    // In dense mode, only show a label (and additional date) for Mondays.
    const dayAbbreviations = ["Su", "Mo", "Tu", "We", "Th", "Fr", "Sa"]
    const isLabelled = (date) => !denseMode || date.getDay() === 1
    const xAxis = d3.axisBottom(xScale).tickFormat((d) => (isLabelled(d) ? dayAbbreviations[d.getDay()] : ""))
    const yAxis = d3.axisLeft(yScale).tickFormat(d3.format(".2f"))

    // Horizontal grid lines
    chartGroup
      .append("g")
      .attr("class", "grid")
      .attr("color", "#E7E7E7")
      .attr("stroke-dasharray", "3,3")
      .call(d3.axisLeft(yScale).ticks(5).tickSize(-innerWidth).tickFormat(""))
      .call((g) => g.select(".domain").remove())

    // Vertical dashed grid lines
    const verticalGrid = chartGroup
      .append("g")
      .attr("class", "grid")
      .attr("color", "#E7E7E7")
      .attr("stroke-dasharray", "3,3")
      .attr("transform", `translate(0, ${innerHeight})`)
      .call(d3.axisBottom(xScale).tickSize(-innerHeight).tickFormat(""))
      .call((g) => g.select(".domain").remove())

    // Make weekend vertical grid lines darker
    verticalGrid.selectAll(".tick").each(function (d) {
      if (isWeekend(d)) {
        d3.select(this).select("line").style("stroke", "#bbbbbb")
      }
    })

    // X-axis
    const xAxisG = chartGroup.append("g").attr("transform", `translate(0, ${innerHeight})`).call(xAxis)
    xAxisG.selectAll("path, line").style("stroke", "#888888")
    xAxisG.selectAll("text").attr("dy", "1em").style("text-anchor", "middle").style("color", "#888888")

    // Highlight weekends only when NOT in dense mode
    if (!denseMode) {
      const rectHeight = 16
      const rectWidth = 18
      const roundRadius = 6

      xAxisG.selectAll(".tick").each(function (d) {
        if (!isWeekend(d)) {
          return
        }
        const tick = d3.select(this)

        // Inserted before the label so the text is painted on top of the pill
        tick
          .insert("rect", "text")
          .attr("x", -rectWidth / 2)
          .attr("y", rectHeight / 2)
          .attr("width", rectWidth)
          .attr("height", rectHeight)
          .attr("rx", roundRadius)
          .attr("ry", roundRadius)
          .style("fill", "black")

        tick.select("text").style("fill", "white")
      })
    }

    // Additional date label
    xAxisG
      .selectAll(".tick")
      .append("text")
      .attr("dy", "35px")
      .attr("font-size", "10px")
      .attr("font-family", "sans-serif")
      .attr("fill", "#888888")
      .text((d) => (isLabelled(d) ? formatShortDate(d) : ""))

    const yAxisG = chartGroup.append("g").call(yAxis)
    yAxisG.selectAll("path, line").style("stroke", "#888888")
    yAxisG.selectAll("text").style("fill", "#888888")

    // Y-axis label
    // NOTE(review): no .text() is set, so this label currently renders empty.
    yAxisG
      .append("text")
      .attr("transform", "rotate(-90)")
      .attr("x", -innerHeight / 2)
      .attr("y", -35)
      .attr("dy", "1em")
      .style("text-anchor", "middle")
      .style("fill", "#888888")

    // Define gradient for area fill
    const gradientId = SleepFragmentationGraphD3.gradientIdForColor(this.color)
    const defs = svg.append("defs")
    const gradient = defs
      .append("linearGradient")
      .attr("id", gradientId)
      .attr("gradientUnits", "userSpaceOnUse")
      .attr("x1", 0)
      .attr("y1", yScale(0))
      .attr("x2", 0)
      .attr("y2", yScale(yDomainMax))

    gradient.append("stop").attr("offset", "0%").attr("stop-color", this.color).attr("stop-opacity", 0)
    gradient.append("stop").attr("offset", "100%").attr("stop-color", this.color).attr("stop-opacity", 0.3)

    // Draw area and line
    const area = d3
      .area()
      .defined(hasValue)
      .x(centerX)
      .y0(yScale(0))
      .y1((d) => yScale(d.fragmentation))

    chartGroup.append("path").datum(aggregatedData).attr("fill", `url(#${gradientId})`).attr("d", area)

    const line = d3
      .line()
      .defined(hasValue)
      .x(centerX)
      .y((d) => yScale(d.fragmentation))

    chartGroup
      .append("path")
      .datum(aggregatedData)
      .attr("fill", "none")
      .attr("stroke", this.color)
      .attr("stroke-width", 1)
      .attr("d", line)

    // Draw dashline for missing data: bridge two valid points whenever at
    // least one day without data lies between them.
    let lastValidIndex = -1
    aggregatedData.forEach((point, index) => {
      if (!hasValue(point)) {
        return
      }
      if (lastValidIndex >= 0 && index - lastValidIndex > 1) {
        const previous = aggregatedData[lastValidIndex]
        chartGroup
          .append("line")
          .attr("x1", centerX(previous))
          .attr("y1", yScale(previous.fragmentation))
          .attr("x2", centerX(point))
          .attr("y2", yScale(point.fragmentation))
          .attr("stroke", this.color)
          .attr("stroke-width", 1)
          .attr("stroke-linecap", "round")
          .attr("stroke-dasharray", "2,3")
          .style("opacity", 0.5)
      }
      lastValidIndex = index
    })

    // Draw Data Points
    chartGroup
      .selectAll(".frag-circle")
      .data(validData)
      .enter()
      .append("circle")
      .attr("class", "frag-circle")
      .attr("cx", centerX)
      .attr("cy", (d) => yScale(d.fragmentation))
      .attr("r", 3)
      .attr("fill", this.color)

    // Only show numeric labels above each dot when NOT in dense mode.
    if (!denseMode) {
      chartGroup
        .selectAll(".frag-label")
        .data(validData)
        .enter()
        .append("text")
        .attr("class", "frag-label")
        .attr("x", centerX)
        .attr("y", (d) => yScale(d.fragmentation) - 6)
        .attr("text-anchor", "middle")
        .attr("font-size", "10px")
        .attr("font-family", "sans-serif")
        .attr("fill", this.color)
        .text((d) => d.fragmentation.toFixed(2))
    }

    // Tooltip & focus line
    const tooltip = container
      .append("div")
      .attr("class", "sleep-fragmentation-tooltip")
      .style("position", "absolute")
      .style("padding", "4px 12px")
      .style("background", "#000")
      .style("border-radius", "20px")
      .style("color", "#fff")
      .style("font-size", "14px")
      .style("pointer-events", "none")
      .style("opacity", 0)
      .style("white-space", "nowrap")

    // Focus line & circle
    const focusLine = chartGroup
      .append("line")
      .attr("class", "focusLine")
      .attr("stroke", "#000")
      .attr("stroke-width", 1)
      .attr("stroke-dasharray", "3,3")
      .style("opacity", 0)

    const focusCircle = chartGroup.append("circle").attr("r", 4).attr("fill", "#000").style("opacity", 0)

    // Transparent overlay that captures pointer events over the plot area
    chartGroup
      .append("rect")
      .attr("class", "overlay")
      .attr("width", innerWidth)
      .attr("height", innerHeight)
      .style("fill", "none")
      .style("pointer-events", "all")
      .on("mousemove", (event) =>
        this.handleMouseMove(event, aggregatedData, xScale, yScale, tooltip, focusLine, focusCircle),
      )
      .on("mouseout", () => {
        tooltip.style("opacity", 0)
        focusLine.style("opacity", 0)
        focusCircle.style("opacity", 0)
      })
  }

  /**
   * Snaps the hover indicators to the day whose band centre is closest to the
   * pointer and fills the tooltip with that day's values ("?" when the day
   * has no fragmentation value).
   *
   * @param {Event} event - Pointer event from the overlay rect.
   * @param {Array<Object>} data - Aggregated per-day records (see `aggregateNight`).
   * @param {Function} xScale - Band scale over the record dates.
   * @param {Function} yScale - Linear scale for fragmentation.
   * @param {Object} tooltip - D3 selection of the tooltip div.
   * @param {Object} focusLine - D3 selection of the vertical focus line.
   * @param {Object} focusCircle - D3 selection of the focus circle.
   */
  handleMouseMove(event, data, xScale, yScale, tooltip, focusLine, focusCircle) {
    const [xPos] = d3.pointer(event)
    const xPositions = data.map((d) => xScale(d.date) + xScale.bandwidth() / 2)
    const index = xPositions.reduce((prevIndex, currX, currIndex) => {
      return Math.abs(currX - xPos) < Math.abs(xPositions[prevIndex] - xPos) ? currIndex : prevIndex
    }, 0)

    const d = data[index]
    if (!d) return

    const hasData = d.fragmentation !== null
    const x = xPositions[index]

    // Focus line: from the top edge of the svg down to the x-axis. The lower
    // end is the inner plot height, derived from the margins rather than a
    // literal that merely happens to equal margin.top.
    focusLine
      .attr("x1", x)
      .attr("x2", x)
      .attr("y1", 0 - this.margin.top)
      .attr("y2", this.height - this.margin.top - this.margin.bottom)
      .style("opacity", 0.4)

    // Focus circle
    if (hasData) {
      focusCircle.attr("cx", x).attr("cy", yScale(d.fragmentation)).style("opacity", 1)
    } else {
      focusCircle.style("opacity", 0)
    }

    // Tooltip content
    const dateStr = d3.timeFormat("%m/%d")(d.date)
    const fragStr = hasData ? d.fragmentation.toFixed(2) : "?"
    const totalSleepStr = hasData ? d.totalSleep.toFixed(2) : "?"
    const awakeStr = hasData ? d.awakeCount : "?"

    tooltip
      .html(
        `<div style="text-align:left">
          <div>${dateStr}</div>
          <div>Awake: ${awakeStr} times</div>
          <div>Total Sleep: ${totalSleepStr} hrs</div>
          <div>Fragmentation: ${fragStr}</div>
        </div>`,
      )
      .style("left", `${event.pageX + 10}px`)
      .style("top", `${event.pageY - 40}px`)
      .style("opacity", 1)
  }

  /**
   * Appends the legend (colour swatch plus the fragmentation formula) near
   * the bottom-left corner of the svg.
   *
   * @param {Object} svg - D3 selection of the root svg element.
   */
  addSleepLegend(svg) {
    const sleepLegend = svg.append("g").attr("transform", `translate(${this.margin.left}, ${this.height - 20})`)

    // SVG text is painted with `fill`; the CSS `color` property alone does
    // not change it.
    sleepLegend
      .append("text")
      .text("fragmentation = # of awake events / hours slept")
      .style("font-size", "14px")
      .style("font-family", "sans-serif")
      .style("fill", "#888888")
      .attr("x", 24)
      .attr("y", 14)

    // Swatch follows the configured series colour so it matches the line.
    sleepLegend
      .append("rect")
      .attr("width", 18)
      .attr("height", 18)
      .attr("rx", 5)
      .attr("ry", 5)
      .attr("x", 0)
      .attr("y", 0)
      .attr("fill", this.color)
  }
}

// Expose the chart class as this module's default export.
export default SleepFragmentationGraphD3
