import React, { useCallback, useMemo } from "react"
import { range } from "lodash"
import { VirtualItem } from "@tanstack/react-virtual"
import { cn } from "@daybridge/cn"
import { DateTime } from "luxon"
import { DayOfWeek, startOfWeek } from "../_utils/weekStart"
import { VirtualizedView } from "./VirtualizedView"
import { useDisplayedDateRangeWrite } from "./context/DisplayedDateRangeContext"
import {
  useNavigationDateRead,
  useNavigationDateWrite,
} from "./context/NavigationDateContext"

/**
 * One row of the grid as handed to the `children` render function of
 * `GridVirtualizedView`. Each row covers a seven-day span.
 */
export type GridVirtualizedViewItem = {
  // Identifier for the row. The component in this file fills it with the ISO
  // date of `start`.
  key: string | number
  // First day of the row.
  start: DateTime
  // End of the row; the component in this file sets it to `start` plus 7 days.
  end: DateTime
  // Index of the row within the virtualized list.
  index: number
}

/**
 * Props for `GridVirtualizedView`. All standard `div` attributes except
 * `children` are accepted and forwarded to the underlying `VirtualizedView`.
 */
type GridVirtualizedViewProps = Omit<
  React.HTMLAttributes<HTMLDivElement>,
  "children"
> & {
  // Day on which each row starts. Defaults to Monday.
  weekStartsOn?: DayOfWeek
  // Renders one header cell per weekday. When omitted, no header row is
  // rendered.
  headerForDay?: (day: DateTime) => React.ReactNode
  // Total number of week rows generated around the navigation date.
  // Defaults to 180.
  count?: number
  // Size given to every row (presumably pixels; confirm against
  // VirtualizedView's `itemSizes`). Defaults to 200.
  rowHeight?: number
  // Render function that receives the currently virtualized rows.
  children: (items: GridVirtualizedViewItem[]) => React.ReactNode
}

// Height of the sticky weekday header row. Also passed to VirtualizedView as
// `scrollPadding`.
const HEADER_HEIGHT = 48

/**
 * Vertically virtualized grid in which every row represents one week.
 *
 * `count` week rows are generated around the week that contains the current
 * navigation date. The visible rows are handed to the `children` render
 * function, and an optional sticky header with one cell per weekday is
 * rendered above them when `headerForDay` is provided.
 *
 * The component reads the navigation date from `NavigationDateContext`,
 * writes it back when the view lands on a row other than the center one, and
 * publishes the visible date range through `DisplayedDateRangeContext`.
 *
 * Any remaining `div` attributes and the forwarded ref are passed through to
 * `VirtualizedView`.
 */
const GridVirtualizedViewFn = React.forwardRef(
  (
    props: GridVirtualizedViewProps,
    ref: React.ForwardedRef<HTMLDivElement>,
  ) => {
    const {
      count = 180,
      weekStartsOn = DayOfWeek.Monday,
      rowHeight = 200,
      headerForDay,
      children,
      ...rest
    } = props

    const { navigationDate } = useNavigationDateRead()
    const { setNavigationDate } = useNavigationDateWrite()
    const { setDateRange } = useDisplayedDateRangeWrite()

    // Memoized so that the derived arrays and callbacks below keep a stable
    // identity between renders while the navigation date is unchanged.
    const firstDayOfWeek = useMemo(
      () => startOfWeek(navigationDate, weekStartsOn),
      [navigationDate, weekStartsOn],
    )

    // Headers: the seven days of the week containing the navigation date.
    const headerDays = useMemo(() => {
      return range(0, 7).map((offset) => firstDayOfWeek.plus({ days: offset }))
    }, [firstDayOfWeek])

    // Work with a non-negative whole number of rows so that an odd or
    // fractional `count` still yields whole-week offsets and a valid center.
    const rowCount = Math.max(0, Math.floor(count))
    const centerIndex = Math.floor(rowCount / 2)

    // Start date of every row; the row at `centerIndex` is the current week.
    const items = useMemo(() => {
      const centerRowStart = firstDayOfWeek.startOf("day")
      return range(-centerIndex, rowCount - centerIndex).map((weekOffset) =>
        centerRowStart.plus({ days: weekOffset * 7 }),
      )
    }, [firstDayOfWeek, rowCount, centerIndex])

    const itemSizes = useMemo(() => {
      return items.map(() => rowHeight)
    }, [items, rowHeight])

    const onLandOnIndex = useCallback(
      (index: number) => {
        if (index === centerIndex) {
          // Already centered
          return
        }
        const landedOn = items[index]
        if (landedOn === undefined) {
          // Index outside the generated rows; nothing to navigate to.
          return
        }
        setNavigationDate(landedOn)
      },
      [centerIndex, items, setNavigationDate],
    )

    const onVisibleRangeChange = useCallback(
      ([firstIndex, lastIndex]: [number, number]) => {
        const firstDate = items[firstIndex]
        const lastDate = items[lastIndex]
        if (firstDate === undefined || lastDate === undefined) {
          // Never publish a range with a missing endpoint.
          return
        }
        setDateRange?.([firstDate, lastDate])
      },
      [items, setDateRange],
    )

    // A row forces a snap when its week starts on, or runs across, the first
    // day of a month.
    const shouldForceSnap = useCallback(
      (index: number) => {
        const weekStart = items[index]
        if (weekStart === undefined) {
          return false
        }
        const nextWeekStart = weekStart.plus({ days: 7 })
        return (
          datePeriodContainsMonthChange(weekStart, nextWeekStart) !== undefined
        )
      },
      [items],
    )

    const render = useCallback(
      (virtualItems: VirtualItem[]) => {
        return (
          <>
            {headerForDay && (
              <div
                className={cn(
                  "sticky z-50 top-0 left-0",
                  "w-full",
                  "bg-surface",
                  "border-b border-tint",
                  "grid grid-cols-7",
                )}
                style={{
                  height: HEADER_HEIGHT,
                }}
              >
                {headerDays.map((day) => (
                  // Keyed fragment: gives every header cell a stable key
                  // without adding an element to the grid.
                  <React.Fragment key={day.toMillis()}>
                    {headerForDay(day)}
                  </React.Fragment>
                ))}
              </div>
            )}
            {children(
              virtualItems.map((item) => ({
                index: item.index,
                key: items[item.index].toISODate(),
                start: items[item.index],
                end: items[item.index].plus({ days: 7 }),
              })),
            )}
          </>
        )
      },
      [headerDays, headerForDay, items, children],
    )

    return (
      <VirtualizedView
        orientation="vertical"
        itemSizes={itemSizes}
        onLandOnIndex={onLandOnIndex}
        onVisibleRangeChange={onVisibleRangeChange}
        shouldForceSnap={shouldForceSnap}
        windowSize="100%"
        scrollPadding={HEADER_HEIGHT}
        ref={ref}
        {...rest}
      >
        {render}
      </VirtualizedView>
    )
  },
)
GridVirtualizedViewFn.displayName = "GridVirtualizedView"

// Memoized public export. The cast keeps the type of the forwardRef
// component rather than the type returned by `React.memo`.
export const GridVirtualizedView = React.memo(
  GridVirtualizedViewFn,
) as typeof GridVirtualizedViewFn

/**
 * Determines whether the period `[start, end)` begins on, or runs across, the
 * first day of a calendar month.
 *
 * NOTE: only month numbers are compared and the boundary is derived from
 * `end`, so the period is expected to be shorter than one month. The caller
 * in this file always passes a seven-day period.
 *
 * @param start - First day of the period (inclusive).
 * @param end - Day immediately after the period (exclusive).
 * @returns `undefined` when the whole period lies within a single month and
 * does not start on the 1st. Otherwise returns the first day of the new month
 * (`newMonthDate`) and its offset from `start` in calendar days (`position`),
 * which is `0` when `start` itself is the 1st.
 */
export const datePeriodContainsMonthChange = (
  start: DateTime,
  end: DateTime,
):
  | {
      newMonthDate: DateTime
      position: number
    }
  | undefined => {
  // Check if start date is at the beginning of a month
  if (start.day === 1) {
    return {
      newMonthDate: start,
      position: 0,
    }
  }

  // `end` is exclusive, so the last day actually covered is the day before.
  const lastDayOfPeriod = end.minus({ days: 1 })
  if (start.month === lastDayOfPeriod.month) {
    return undefined
  }

  const newMonth = end.startOf("month")
  // Diff in calendar days rather than converting elapsed milliseconds: a DST
  // transition inside the period would otherwise make the offset fractional.
  const position = newMonth.diff(start, "days").days

  return {
    newMonthDate: newMonth,
    position,
  }
}
