// Adapted from https://github.com/adobe/react-spectrum/blob/main/packages/%40react-stately/tree/src/TreeCollection.ts

import { Collection, Node } from '@react-types/shared';
import { Key } from 'react';

import { GenericMenuItem } from './types';

/**
 * A collection `Node` augmented with menu-specific metadata.
 *
 * The extra fields below are written by the `MenuCollection` constructor
 * when it walks the node tree.
 *
 * @typeParam T - Type of this node's `value`.
 * @typeParam C - Type of the `value` carried by this node's child nodes.
 */
export type MenuNode<T = GenericMenuItem, C = GenericMenuItem> = Node<T> & {
  /** True when this node is the first of its siblings (or the first top-level node). */
  firstInMenu: boolean;
  /**
   * Nesting depth: 1 for top-level nodes. Children of a node whose `type` is
   * `'item'` are one level deeper; children of other node types keep their
   * parent's level.
   */
  menuLevel: number;
  /**
   * Key of the closest ancestor whose `type` is `'item'`.
   * NOTE(review): `MenuCollection` assigns `null` (not `undefined`) when no such
   * ancestor exists, which this optional `Key` type does not express — confirm
   * how consumers test for "no parent".
   */
  parentMenuItemKey?: Key;
  /** True when the node's value is flagged `isDisabled` or its value `type` is `'divider'`. */
  isDisabled?: boolean;
  /** Child nodes; their values are typed by `C`. */
  childNodes: Iterable<MenuNode<C>>;
};

/**
 * Read-only `Collection` over a (possibly nested) menu tree.
 *
 * On construction the tree is walked depth-first in pre-order. Every node,
 * including nested ones, is annotated in place with `firstInMenu`, `menuLevel`,
 * `parentMenuItemKey` and `isDisabled`. It is then registered in a flat key map.
 * The insertion order of that map defines `getKeys()`, `size`, `at()` and the
 * `getKeyBefore` / `getKeyAfter` navigation chain.
 *
 * Iterating the collection itself (`for..of`) yields only the top-level nodes
 * that were passed to the constructor.
 */
export class MenuCollection implements Collection<MenuNode> {
  private keyMap: Map<Key, MenuNode> = new Map();
  private iterable: Iterable<MenuNode>;
  private firstKey: Key;
  private lastKey: Key;

  /**
   * @param nodes - Top-level menu nodes. These nodes and all their descendants
   *   are mutated in place: menu metadata, `prevKey`/`nextKey`, and `index`
   *   (items only) are written onto them.
   */
  constructor(nodes: Iterable<MenuNode>) {
    this.iterable = nodes;

    // Pre-order walk: annotate and register a node, then recurse into its children.
    const visit = (node: MenuNode, firstInMenu: boolean, menuLevel: number, parentMenuItemKey: Key | null) => {
      node.firstInMenu = firstInMenu;
      node.menuLevel = menuLevel;
      node.parentMenuItemKey = parentMenuItemKey;
      // Dividers are always treated as disabled. `value` is guarded because a
      // collection node's value may be null.
      node.isDisabled = node.value?.isDisabled || node.value?.type === 'divider';

      this.keyMap.set(node.key, node);

      if (!node.childNodes) {
        return;
      }

      // Only `item` nodes start a deeper menu level and become the parent key of
      // their children; other node types pass the current level/parent through.
      const isItem = node.type === 'item';
      const childLevel = isItem ? menuLevel + 1 : menuLevel;
      const childParentKey = isItem ? node.key : parentMenuItemKey;

      let first = true;
      for (const child of node.childNodes) {
        visit(child, first, childLevel, childParentKey);
        first = false;
      }
    };

    let first = true;
    for (const node of nodes) {
      visit(node, first, 1, null);
      first = false;
    }

    // Thread prev/next keys through the flattened key map (insertion order) and
    // assign a running index to `item` nodes only.
    let last: MenuNode | undefined;
    let index = 0;
    for (const [key, node] of this.keyMap) {
      if (last) {
        last.nextKey = key;
        node.prevKey = last.key;
      } else {
        this.firstKey = key;
        node.prevKey = undefined;
      }

      if (node.type === 'item') {
        node.index = index++;
      }

      // Provisionally the tail; overwritten on the next iteration if another node follows.
      node.nextKey = undefined;
      last = node;
    }

    // For an empty collection `lastKey` (like `firstKey`) stays undefined.
    if (last) {
      this.lastKey = last.key;
    }
  }

  /** Yields the top-level nodes passed to the constructor (not the nested ones). */
  *[Symbol.iterator]() {
    yield* this.iterable;
  }

  /** Number of registered nodes, including nested ones. */
  get size() {
    return this.keyMap.size;
  }

  /**
   * Returns the node at position `pos` in key order, i.e. the same order as
   * `getKeys()`. Valid positions are therefore `0 .. size - 1`.
   *
   * @returns The node, or `undefined` when `pos` is out of range.
   */
  at(pos: number) {
    let i = 0;
    for (const node of this.keyMap.values()) {
      if (i++ === pos) {
        return node;
      }
    }

    return undefined;
  }

  /** All node keys (nested included) in pre-order traversal order. */
  getKeys() {
    return this.keyMap.keys();
  }

  /**
   * @returns The previous key in traversal order, `undefined` for the first node,
   *   or `null` when `key` is not in the collection.
   */
  getKeyBefore(key: Key) {
    const node = this.keyMap.get(key);
    return node ? node.prevKey : null;
  }

  /**
   * @returns The next key in traversal order, `undefined` for the last node,
   *   or `null` when `key` is not in the collection.
   */
  getKeyAfter(key: Key) {
    const node = this.keyMap.get(key);
    return node ? node.nextKey : null;
  }

  /** Key of the first node in traversal order (undefined when empty). */
  getFirstKey() {
    return this.firstKey;
  }

  /** Key of the last node in traversal order (undefined when empty). */
  getLastKey() {
    return this.lastKey;
  }

  /** Looks up any node (top-level or nested) by key. */
  getItem(key: Key) {
    return this.keyMap.get(key);
  }
}
