import React from 'react';
import { keepPreviousData, useQuery } from '@tanstack/react-query';

import {
  PaginationState,
  useReactTable,
  getCoreRowModel,
  ColumnDef,
  flexRender,
  SortingState,
} from '@tanstack/react-table';

import Search from './Search';
import { Button } from 'src/components/ui/button';
import { getPaginationRange } from 'src/lib/utils';
import EmptyState from './EmptyState';
import SkeletonTable from '../Skeleton/TableSkeleton';
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '../ui/select';

/**
 * Props accepted by the generic server-paginated `TanstackTable` component.
 *
 * @typeParam T - Shape of a single row returned by `fetchDataFunc`.
 */
interface TanstackTableProps<T> {
  /** Column definitions handed directly to `useReactTable`. */
  columns: Array<ColumnDef<T>>;
  /**
   * Loads one page of rows. It is called with the current pagination state
   * (`pageIndex`, `pageSize`) merged with the active filters, sorting and search
   * term. It resolves to that page's rows plus the total row count, which the
   * table uses to derive the page count.
   */
  fetchDataFunc: (
    params: {
      filters?: Record<string, any>;
      sorting?: SortingState;
      search?: string;
    } & PaginationState,
  ) => Promise<{ rows: T[]; rowCount: number }>;
  /**
   * Optional filter UI rendered beside the search box. It receives the current
   * filter state and a setter for it.
   */
  filters?: React.ComponentType<{
    filtersState: Record<string, any>;
    setFiltersState: any;
  }>;
  /** Leading segment of the React Query cache key; defaults to ''. */
  queryKey?: string;
}

/**
 * Generic data table with server-side pagination and sorting, built on
 * TanStack Table and TanStack Query.
 *
 * Pagination, sorting, filter and search state live in this component. They are
 * part of the React Query key and are forwarded to `fetchDataFunc`. The table runs
 * with `manualPagination`/`manualSorting`, so `fetchDataFunc` must return only the
 * requested page together with the total `rowCount`.
 *
 * @param columns - TanStack column definitions for the rows.
 * @param fetchDataFunc - Loads one page given the pagination/filter/sort/search state.
 * @param filters - Optional filter component. It receives the current filter state
 *   and a setter; any filter change sends the table back to the first page.
 * @param queryKey - Leading segment of the React Query cache key (default '').
 */
const TanstackTable = <T extends {}>({
  columns,
  fetchDataFunc,
  filters: FiltersComponent,
  queryKey = '',
}: TanstackTableProps<T>) => {
  const [pagination, setPagination] = React.useState<PaginationState>({
    pageIndex: 0,
    pageSize: 10,
  });
  const [selectedFilters, setSelectedFilters] = React.useState<Record<string, any>>({});
  const [sorting, setSorting] = React.useState<SortingState>([]);
  // NOTE(review): `search` has no setter and <Search /> below is not connected to it,
  // so it is always ''. Wire it up once Search exposes a value/onChange API.
  const [search] = React.useState('');

  // A filter change can shrink the result set below the current page. Jump back to
  // the first page; otherwise the user would see EmptyState, which hides the
  // pagination controls and leaves no way to navigate back.
  const handleFiltersChange = React.useCallback(
    (next: React.SetStateAction<Record<string, any>>) => {
      setSelectedFilters(next);
      setPagination((prev) => (prev.pageIndex === 0 ? prev : { ...prev, pageIndex: 0 }));
    },
    [],
  );

  const dataQuery = useQuery({
    queryKey: [queryKey, pagination, selectedFilters, sorting, search],
    queryFn: () =>
      fetchDataFunc({
        ...pagination,
        filters: selectedFilters,
        sorting,
        search,
      }),
    placeholderData: keepPreviousData, // don't have 0 rows flash while changing pages/loading next page
  });

  // Stable empty array so the table doesn't receive a new `data` reference on every render.
  const defaultData = React.useMemo<T[]>(() => [], []);

  const table = useReactTable({
    data: dataQuery.data?.rows ?? defaultData,
    columns,
    // `pageCount` is derived internally from `rowCount` (TanStack Table >= 8.13).
    rowCount: dataQuery.data?.rowCount,
    // Every slice of state that has an on*Change handler must be passed back in.
    // Otherwise the table keeps reading its own internal copy, which those handlers
    // never update.
    state: {
      pagination,
      sorting,
    },
    onPaginationChange: setPagination,
    onSortingChange: setSorting,
    getCoreRowModel: getCoreRowModel(),
    manualPagination: true, // server-side pagination
    manualSorting: true, // server-side sorting
  });

  const totalPages = table.getPageCount();
  const currentPage = table.getState().pagination.pageIndex + 1; // 1-based for display
  const paginationRange = getPaginationRange(currentPage, totalPages);
  const isEmpty = dataQuery.data?.rows?.length === 0;
  // NOTE(review): the skeleton replaces the table on every fetch, including page
  // changes, so the rows kept by `keepPreviousData` are never actually displayed.
  const isLoading = dataQuery.isLoading || dataQuery.isFetching;

  return (
    <div className="border p-4">
      <div className="mb-4 flex justify-between">
        <Search />
        {FiltersComponent && (
          <FiltersComponent filtersState={selectedFilters} setFiltersState={handleFiltersChange} />
        )}
      </div>
      {isLoading ? (
        <SkeletonTable rows={pagination.pageSize} />
      ) : dataQuery.isError ? (
        <div role="alert" className="flex flex-col items-center gap-2 py-8 text-sm text-neutral-700">
          <span>Failed to load data.</span>
          <Button variant="outline" onClick={() => void dataQuery.refetch()}>
            Retry
          </Button>
        </div>
      ) : isEmpty ? (
        <EmptyState />
      ) : (
        <>
          <div className="overflow-x-auto">
            <table className="min-w-full table-fixed border-collapse border-x border-t border-neutral-50">
              <thead>
                {table.getHeaderGroups().map((headerGroup) => (
                  <tr key={headerGroup.id}>
                    {headerGroup.headers.map((header) => (
                      <th
                        className="border-b px-6 py-3 text-left text-sm font-medium leading-4 text-neutral-700"
                        key={header.id}
                        colSpan={header.colSpan}
                      >
                        {header.isPlaceholder ? null : (
                          <div>{flexRender(header.column.columnDef.header, header.getContext())}</div>
                        )}
                      </th>
                    ))}
                  </tr>
                ))}
              </thead>
              <tbody>
                {table.getRowModel().rows.map((row) => (
                  <tr key={row.id}>
                    {row.getVisibleCells().map((cell) => (
                      <td
                        key={cell.id}
                        className="whitespace-nowrap border-b px-6 py-4 text-sm leading-5"
                      >
                        {flexRender(cell.column.columnDef.cell, cell.getContext())}
                      </td>
                    ))}
                  </tr>
                ))}
              </tbody>
            </table>
          </div>
          <div className="h-2" />
          <div className="mt-4 flex flex-col items-center justify-between sm:flex-row">
            <div className="flex items-baseline">
              <div className="mr-2 hidden sm:block">
                Page {currentPage} of {totalPages}
              </div>
              <Select
                value={String(table.getState().pagination.pageSize)}
                onValueChange={(value) => {
                  table.setPageSize(Number(value));
                }}
              >
                <SelectTrigger className="w-[150px]">
                  <SelectValue placeholder={`Show ${table.getState().pagination.pageSize}`} />
                </SelectTrigger>
                <SelectContent>
                  {[10, 20, 30, 40, 50, 100].map((pageSize) => (
                    <SelectItem key={pageSize} value={String(pageSize)}>
                      Show {pageSize}
                    </SelectItem>
                  ))}
                </SelectContent>
              </Select>
            </div>
            {totalPages > 1 && (
              <div className="my-2 flex space-x-1 sm:my-0">
                {paginationRange?.map((page, index) => {
                  if (page === '...') {
                    return (
                      <span key={index} className="px-3 py-1">
                        {page}
                      </span>
                    );
                  }

                  const isCurrent = currentPage === page;
                  return (
                    <button
                      key={index}
                      // Explicit type so the button never submits an enclosing <form>.
                      type="button"
                      aria-current={isCurrent ? 'page' : undefined}
                      className={`px-3 py-1 ${
                        isCurrent
                          ? 'rounded-md border border-primary-500 font-semibold text-primary-500'
                          : ''
                      }`}
                      onClick={() => table.setPageIndex(Number(page) - 1)}
                    >
                      {page}
                    </button>
                  );
                })}
              </div>
            )}
            <div>
              <Button
                className="rounded-md px-3 py-1 sm:mx-4"
                onClick={() => table.previousPage()}
                disabled={!table.getCanPreviousPage()}
                variant="outline"
              >
                Previous
              </Button>
              <Button
                className="rounded-md px-3 py-1"
                onClick={() => table.nextPage()}
                disabled={!table.getCanNextPage()}
                variant="outline"
              >
                Next
              </Button>
            </div>
          </div>
        </>
      )}
    </div>
  );
};

export default TanstackTable;
