diff --git a/webui/react/src/components/ComparisonView.test.mock.tsx b/webui/react/src/components/ComparisonView.test.mock.tsx index 908b9454530..9d5f5a54fd8 100644 --- a/webui/react/src/components/ComparisonView.test.mock.tsx +++ b/webui/react/src/components/ComparisonView.test.mock.tsx @@ -1,11 +1,14 @@ +import { useObservable } from 'micro-observables'; import React from 'react'; import { useGlasbey } from 'hooks/useGlasbey'; import { RunMetricData } from 'hooks/useMetrics'; +import { V1LocationType } from 'services/api-ts-sdk'; import { ExperimentWithTrial, Scale } from 'types'; import { generateTestRunData } from 'utils/tests/generateTestData'; import ComparisonView from './ComparisonView'; +import { FilterFormStore } from './FilterForm/components/FilterFormStore'; export const METRIC_DATA: RunMetricData = { data: { @@ -245,6 +248,7 @@ export const ExperimentComparisonViewWithMocks: React.FC = ({ onWidthChange, open, }: Props): JSX.Element => { + const tableFilters = useObservable(new FilterFormStore(V1LocationType.EXPERIMENT).asJsonString); const colorMap = useGlasbey(SELECTED_EXPERIMENTS.map((exp) => exp.experiment.id)); return ( = ({ initialWidth={200} open={open} projectId={1} + tableFilters={tableFilters} onWidthChange={onWidthChange}> {children} @@ -270,6 +275,7 @@ export const RunComparisonViewWithMocks: React.FC = ({ onWidthChange, open, }: Props): JSX.Element => { + const tableFilters = useObservable(new FilterFormStore(V1LocationType.RUN).asJsonString); const colorMap = useGlasbey(SELECTED_RUNS.map((run) => run.id)); return ( = ({ ? { selections: [], type: 'ONLY_IN' } : { selections: SELECTED_RUNS.map((run) => run.id), type: 'ONLY_IN' } } + tableFilters={tableFilters} onWidthChange={onWidthChange}> {children} diff --git a/webui/react/src/components/ComparisonView.tsx b/webui/react/src/components/ComparisonView.tsx index c5d5d018f06..610227af98c 100644 --- a/webui/react/src/components/ComparisonView.tsx +++ b/webui/react/src/components/ComparisonView.tsx @@ -15,13 +15,16 @@ import useMobile from 'hooks/useMobile'; import useScrollbarWidth from 'hooks/useScrollbarWidth'; import { TrialsComparisonTable } from 'pages/ExperimentDetails/TrialsComparisonModal'; import { searchExperiments, searchRuns } from 'services/api'; +import { V1ColumnType, V1LocationType } from 'services/api-ts-sdk'; import { ExperimentWithTrial, FlatRun, SelectionType, XOR } from 'types'; import handleError from 'utils/error'; import { getIdsFilter as getExperimentIdsFilter } from 'utils/experiment'; +import { combine } from 'utils/filterFormSet'; import { getIdsFilter as getRunIdsFilter } from 'utils/flatRun'; import CompareMetrics from './CompareMetrics'; import { INIT_FORMSET } from './FilterForm/components/FilterFormStore'; +import { FilterFormSet, Operator } from './FilterForm/components/type'; export const EMPTY_MESSAGE = 'No items selected.'; @@ -33,6 +36,8 @@ interface BaseProps { onWidthChange: (width: number) => void; fixedColumnsCount: number; projectId: number; + searchId?: number; + tableFilters: string; } type Props = XOR<{ experimentSelection: SelectionType }, { runSelection: SelectionType }> & @@ -132,6 +137,8 @@ const ComparisonView: React.FC = ({ projectId, experimentSelection, runSelection, + searchId, + tableFilters, }) => { const scrollbarWidth = useScrollbarWidth(); const hasPinnedColumns = fixedColumnsCount > 1; @@ -148,7 +155,10 @@ const ComparisonView: React.FC = ({ return NotLoaded; } try { - const filterFormSet = INIT_FORMSET; + const filterFormSet = + experimentSelection.type === 'ALL_EXCEPT' + ? (JSON.parse(tableFilters) as FilterFormSet) + : INIT_FORMSET; const filter = getExperimentIdsFilter(filterFormSet, experimentSelection); const response = await searchExperiments({ filter: JSON.stringify(filter), @@ -162,7 +172,7 @@ const ComparisonView: React.FC = ({ handleError(e, { publicSubject: 'Unable to fetch experiments for comparison' }); return NotLoaded; } - }, [experimentSelection, open]); + }, [experimentSelection, open, tableFilters]); const loadableSelectedRuns = useAsync(async () => { if ( @@ -172,12 +182,28 @@ const ComparisonView: React.FC = ({ ) { return NotLoaded; } - const filterFormSet = INIT_FORMSET; + const filterFormSet = + runSelection.type === 'ALL_EXCEPT' + ? (JSON.parse(tableFilters) as FilterFormSet) + : INIT_FORMSET; try { const filter = getRunIdsFilter(filterFormSet, runSelection); + if (searchId) { + // only display trials for search + const searchFilter = { + columnName: 'experimentId', + kind: 'field' as const, + location: V1LocationType.RUN, + operator: Operator.Eq, + type: V1ColumnType.NUMBER, + value: searchId, + }; + filter.filterGroup = combine(filter.filterGroup, 'and', searchFilter); + } const response = await searchRuns({ filter: JSON.stringify(filter), limit: SELECTION_LIMIT, + projectId, }); setIsSelectionLimitReached( !!response?.pagination?.total && response?.pagination?.total > SELECTION_LIMIT, @@ -187,7 +213,7 @@ const ComparisonView: React.FC = ({ handleError(e, { publicSubject: 'Unable to fetch runs for comparison' }); return NotLoaded; } - }, [open, runSelection]); + }, [open, projectId, runSelection, searchId, tableFilters]); const minWidths: [number, number] = useMemo(() => { return [fixedColumnsCount * MIN_COLUMN_WIDTH + scrollbarWidth, 100]; diff --git a/webui/react/src/components/ExperimentActionDropdown.tsx b/webui/react/src/components/ExperimentActionDropdown.tsx index 2536e23bcc8..2af2ca397ac 100644 --- a/webui/react/src/components/ExperimentActionDropdown.tsx +++ b/webui/react/src/components/ExperimentActionDropdown.tsx @@ -353,6 +353,7 @@ const ExperimentActionDropdown: React.FC = ({ /> { it('submits a valid create experiment request', async () => { await setup(); - await user.click(screen.getByRole('button', { name: CreateExperimentType.Fork })); + await user.click( + screen.getByRole('button', { name: RunActionCopyMap[CreateExperimentType.Fork] }), + ); expect(mockCreateExperiment).toHaveBeenCalled(); }); }); diff --git a/webui/react/src/components/ExperimentCreateModal.tsx b/webui/react/src/components/ExperimentCreateModal.tsx index 72cc8649d38..59ada006054 100644 --- a/webui/react/src/components/ExperimentCreateModal.tsx +++ b/webui/react/src/components/ExperimentCreateModal.tsx @@ -51,7 +51,7 @@ const ExperimentEntityCopyMap = { trial: 'trial', }; -const RunActionCopyMap = { +export const RunActionCopyMap = { [CreateExperimentType.ContinueTrial]: 'Continue Run', [CreateExperimentType.Fork]: 'Fork', }; @@ -361,7 +361,7 @@ const ExperimentCreateModalComponent = ({ form: idPrefix + FORM_ID, handleError, handler: handleSubmit, - text: type, + text: ExperimentActionCopyMap[type], }} title={titleLabel} onClose={handleModalClose}> diff --git a/webui/react/src/components/ExperimentMoveModal.tsx b/webui/react/src/components/ExperimentMoveModal.tsx index 07e380ea5c8..7fbe5868c92 100644 --- a/webui/react/src/components/ExperimentMoveModal.tsx +++ b/webui/react/src/components/ExperimentMoveModal.tsx @@ -14,14 +14,18 @@ import Link from 'components/Link'; import useFeature from 'hooks/useFeature'; import usePermissions from 'hooks/usePermissions'; import { paths } from 'routes/utils'; -import { moveExperiments } from 'services/api'; -import { V1BulkExperimentFilters } from 'services/api-ts-sdk'; +import { moveSearches } from 'services/api'; +import { V1MoveSearchesRequest } from 'services/api-ts-sdk'; import projectStore from 'stores/projects'; import workspaceStore from 'stores/workspaces'; -import { Project } from 'types'; +import { Project, SelectionType, XOR } from 'types'; import handleError from 'utils/error'; +import { getIdsFilter as getExperimentIdsFilter } from 'utils/experiment'; import { capitalize, pluralizer } from 'utils/string'; +import { INIT_FORMSET } from './FilterForm/components/FilterFormStore'; +import { FilterFormSet } from './FilterForm/components/type'; + const FORM_ID = 'move-experiment-form'; type FormInputs = { @@ -29,19 +33,21 @@ type FormInputs = { workspaceId?: number; }; -interface Props { - excludedExperimentIds?: Map; - experimentIds: number[]; - filters?: V1BulkExperimentFilters; +interface BaseProps { onSubmit?: (successfulIds?: number[]) => void; + selectionSize: number; sourceProjectId: number; sourceWorkspaceId?: number; } +type Props = BaseProps & + XOR<{ experimentIds: number[] }, { selection: SelectionType; tableFilters: string }>; + const ExperimentMoveModalComponent: React.FC = ({ - excludedExperimentIds, experimentIds, - filters, + selection, + selectionSize, + tableFilters, onSubmit, sourceProjectId, sourceWorkspaceId, @@ -54,8 +60,6 @@ const ExperimentMoveModalComponent: React.FC = ({ const projectId = Form.useWatch('projectId', form); const f_flat_runs = useFeature().isOn('flat_runs'); - const entityName = f_flat_runs ? 'searches' : 'experiments'; - useEffect(() => { setDisabled(workspaceId !== 1 && !projectId); }, [workspaceId, projectId, sourceProjectId, sourceWorkspaceId]); @@ -76,6 +80,14 @@ const ExperimentMoveModalComponent: React.FC = ({ } }, [workspaceId]); + // use plurals for indeterminate case + const pluralizerArgs = f_flat_runs + ? (['search', 'searches'] as const) + : (['experiment'] as const); + // we use apply instead of a direct call here because typescript errors when you spread a tuple into arguments + const plural = pluralizer.apply(null, [selectionSize, ...pluralizerArgs]); + const actionCopy = `Move ${capitalize(plural)}`; + const handleSubmit = async () => { if (workspaceId === sourceWorkspaceId && projectId === sourceProjectId) { openToast({ title: 'No changes to save.' }); @@ -84,16 +96,23 @@ const ExperimentMoveModalComponent: React.FC = ({ const values = await form.validateFields(); const projId = values.projectId ?? 1; - if (excludedExperimentIds?.size) { - filters = { ...filters, excludedExperimentIds: Array.from(excludedExperimentIds.keys()) }; + const moveSearchesArgs: V1MoveSearchesRequest = { + destinationProjectId: projId, + sourceProjectId, + }; + + if (tableFilters !== undefined) { + const filterFormSet = + selection.type === 'ALL_EXCEPT' + ? (JSON.parse(tableFilters) as FilterFormSet) + : INIT_FORMSET; + const filter = getExperimentIdsFilter(filterFormSet, selection); + moveSearchesArgs.filter = JSON.stringify(filter); + } else { + moveSearchesArgs.searchIds = experimentIds; } - const results = await moveExperiments({ - destinationProjectId: projId, - experimentIds, - filters, - projectId: sourceProjectId, - }); + const results = await moveSearches(moveSearchesArgs); onSubmit?.(results.successful); @@ -106,19 +125,19 @@ const ExperimentMoveModalComponent: React.FC = ({ if (numSuccesses === 0 && numFailures === 0) { openToast({ - description: `No selected ${entityName} were eligible for moving`, - title: `No eligible ${entityName}`, + description: `No selected ${plural} were eligible for moving`, + title: `No eligible ${plural}`, }); } else if (numFailures === 0) { openToast({ closeable: true, - description: `${results.successful.length} ${entityName} moved to project ${destinationProjectName}`, + description: `${results.successful.length} ${pluralizer.apply(null, [results.successful.length, ...pluralizerArgs])} moved to project ${destinationProjectName}`, link: View Project, title: 'Move Success', }); } else if (numSuccesses === 0) { openToast({ - description: `Unable to move ${numFailures} ${entityName}`, + description: `Unable to move ${numFailures} ${pluralizer.apply(null, [numFailures, ...pluralizerArgs])}`, severity: 'Warning', title: 'Move Failure', }); @@ -127,7 +146,7 @@ const ExperimentMoveModalComponent: React.FC = ({ closeable: true, description: `${numFailures} out of ${ numFailures + numSuccesses - } eligible ${entityName} failed to move + } eligible ${plural} failed to move to project ${destinationProjectName}`, link: View Project, severity: 'Warning', @@ -142,15 +161,6 @@ const ExperimentMoveModalComponent: React.FC = ({ form.setFieldValue('workspaceId', sourceWorkspaceId ?? 1); }, [form, sourceProjectId, sourceWorkspaceId]); - // use plurals for indeterminate case - const entityCount = filters !== undefined ? 2 : experimentIds.length; - const pluralizerArgs = f_flat_runs - ? (['search', 'searches'] as const) - : (['experiment'] as const); - // we use apply instead of a direct call here because typescript errors when you spread a tuple into arguments - const plural = pluralizer.apply(null, [entityCount, ...pluralizerArgs]); - const actionCopy = `Move ${capitalize(plural)}`; - return ( ; labelSingular: string; labelPlural: string; + onActualSelectAll?: () => void; + onClearSelect?: () => void; + pageSize?: number; selectedCount: number; } @@ -17,6 +21,9 @@ const LoadableCount: React.FC = ({ total, labelPlural, labelSingular, + onActualSelectAll, + onClearSelect, + pageSize = 20, selectedCount, }: Props) => { const isMobile = useMobile(); @@ -41,11 +48,37 @@ const LoadableCount: React.FC = ({ }); }, [labelPlural, labelSingular, total, selectedCount]); + const actualSelectAll = useMemo(() => { + return Loadable.match(total, { + _: () => null, + Loaded: (loadedTotal) => { + if (onActualSelectAll && selectedCount >= pageSize && selectedCount < loadedTotal) { + return ( + + ); + } else if (onClearSelect && (selectedCount >= pageSize || selectedCount === loadedTotal)) { + return ( + + ); + } + + return null; + }, + }); + }, [labelPlural, onActualSelectAll, onClearSelect, pageSize, selectedCount, total]); + if (!isMobile) { return ( - - {selectionLabel} - + <> + + {selectionLabel} + + {actualSelectAll} + ); } else { return null; diff --git a/webui/react/src/components/RunActionDropdown.tsx b/webui/react/src/components/RunActionDropdown.tsx index a553e9d0181..e782b73c0a9 100644 --- a/webui/react/src/components/RunActionDropdown.tsx +++ b/webui/react/src/components/RunActionDropdown.tsx @@ -207,7 +207,8 @@ const RunActionDropdown: React.FC = ({ const shared = ( onComplete?.(FlatRunAction.Move, run.id)} diff --git a/webui/react/src/components/RunFilterInterstitialModalComponent.test.tsx b/webui/react/src/components/RunFilterInterstitialModalComponent.test.tsx index 7a05a4fdae1..9cefe4dae02 100644 --- a/webui/react/src/components/RunFilterInterstitialModalComponent.test.tsx +++ b/webui/react/src/components/RunFilterInterstitialModalComponent.test.tsx @@ -111,7 +111,7 @@ describe('RunFilterInterstitialModalComponent', () => { // TODO: is there a better way to test these expectations? expect(filterFormSet.showArchived).toBeTruthy(); - const [, , idFilter] = filterFormSet.filterGroup.children; + const [, idFilter] = filterFormSet.filterGroup.children; for (const child of expectedFilterGroup.children) { expect(filterFormSet.filterGroup.children).toContainEqual(child); } @@ -148,7 +148,7 @@ describe('RunFilterInterstitialModalComponent', () => { const filterFormSet = JSON.parse(filterFormSetString || ''); expect(filterFormSet.showArchived).toBe(false); - const idFilters = filterFormSet.filterGroup.children || []; + const idFilters = filterFormSet.filterGroup.children[0].children || []; expect(idFilters.every((f: FormField) => f.operator === '=')).toBe(true); expect(idFilters.map((f: FormField) => f.value)).toEqual(expectedSelection); }); diff --git a/webui/react/src/components/RunFilterInterstitialModalComponent.tsx b/webui/react/src/components/RunFilterInterstitialModalComponent.tsx index 287b3f4a320..94850975e40 100644 --- a/webui/react/src/components/RunFilterInterstitialModalComponent.tsx +++ b/webui/react/src/components/RunFilterInterstitialModalComponent.tsx @@ -1,5 +1,5 @@ import { useModal } from 'hew/Modal'; -import { Failed, NotLoaded } from 'hew/utils/loadable'; +import { Failed, Loadable, NotLoaded } from 'hew/utils/loadable'; import { forwardRef, useCallback, useImperativeHandle, useRef, useState } from 'react'; import { FilterFormSetWithoutId } from 'components/FilterForm/components/type'; @@ -74,11 +74,13 @@ export const RunFilterInterstitialModalComponent = forwardRef ({ close, open })); - const selectionHasSearchRuns = useAsync( + const selectionHasSearchRuns: Loadable = useAsync( async (canceler) => { if (!isOpen) return NotLoaded; const mergedCanceler = mergeAbortControllers(canceler, closeController.current); - const filterWithSingleFilter = combine(filterFormSet.filterGroup, 'and', { + + const filter: FilterFormSetWithoutId = getIdsFilter(filterFormSet, selection); + filter.filterGroup = combine(filter.filterGroup, 'and', { columnName: 'searcherType', kind: 'field', location: 'LOCATION_TYPE_RUN', @@ -86,13 +88,6 @@ export const RunFilterInterstitialModalComponent = forwardRef { +const SearchTensorBoardModal = ({ workspaceId, selectedSearches }: Props): JSX.Element => { const handleSubmit = async () => { - const managedExperimentIds = selectedExperiments - .filter((exp) => !exp.unmanaged) - .map((exp) => exp.id); + const managedSearchIds = selectedSearches.filter((exp) => !exp.unmanaged).map((exp) => exp.id); openCommandResponse( - await openOrCreateTensorBoard({ experimentIds: managedExperimentIds, filters, workspaceId }), + await openOrCreateTensorBoardSearches({ + searchIds: managedSearchIds, + workspaceId, + }), ); }; @@ -42,4 +37,4 @@ const ExperimentTensorBoardModal = ({ ); }; -export default ExperimentTensorBoardModal; +export default SearchTensorBoardModal; diff --git a/webui/react/src/components/Searches/Searches.tsx b/webui/react/src/components/Searches/Searches.tsx index 254927b80b3..cf23e84a574 100644 --- a/webui/react/src/components/Searches/Searches.tsx +++ b/webui/react/src/components/Searches/Searches.tsx @@ -1,4 +1,3 @@ -import { CompactSelection, GridSelection } from '@glideapps/glide-data-grid'; import { isLeft } from 'fp-ts/lib/Either'; import Column from 'hew/Column'; import { @@ -11,15 +10,7 @@ import { MIN_COLUMN_WIDTH, MULTISELECT, } from 'hew/DataGrid/columns'; -import DataGrid, { - DataGridHandle, - HandleSelectionChangeType, - RangelessSelectionType, - SelectionType, - Sort, - validSort, - ValidSort, -} from 'hew/DataGrid/DataGrid'; +import DataGrid, { DataGridHandle, Sort, validSort, ValidSort } from 'hew/DataGrid/DataGrid'; import { MenuItem } from 'hew/Dropdown'; import Icon from 'hew/Icon'; import Link from 'hew/Link'; @@ -56,6 +47,7 @@ import { useDebouncedSettings } from 'hooks/useDebouncedSettings'; import { useGlasbey } from 'hooks/useGlasbey'; import useMobile from 'hooks/useMobile'; import usePolling from 'hooks/usePolling'; +import useSelection from 'hooks/useSelection'; import { useSettings } from 'hooks/useSettings'; import { useTypedParams } from 'hooks/useTypedParams'; import { paths } from 'routes/utils'; @@ -75,7 +67,6 @@ import { Project, ProjectColumn, RunState, - SelectionType as SelectionState, } from 'types'; import handleError from 'utils/error'; import { getProjectExperimentForExperimentItem } from 'utils/experiment'; @@ -97,8 +88,6 @@ interface Props { project: Project; } -type ExperimentWithIndex = { index: number; experiment: BulkExperimentItem }; - const BANNED_FILTER_COLUMNS = new Set(['searcherMetricsVal']); const BANNED_SORT_COLUMNS = new Set(['tags', 'searcherMetricsVal']); @@ -183,6 +172,15 @@ const Searches: React.FC = ({ project }) => { const isMobile = useMobile(); const { openToast } = useToast(); + const { selectionSize, dataGridSelection, handleSelectionChange, isRangeSelected } = useSelection( + { + records: experiments.map((loadable) => loadable.map((exp) => exp.experiment)), + selection: settings.selection, + total, + updateSettings, + }, + ); + const handlePinnedColumnsCountChange = useCallback( (newCount: number) => updateSettings({ pinnedColumnsCount: newCount }), [updateSettings], @@ -248,34 +246,22 @@ const Searches: React.FC = ({ project }) => { return []; }, [settings.selection]); - const loadedSelectedExperimentIds = useMemo(() => { - const selectedMap = new Map(); + const loadedExperimentIdMap = useMemo(() => { + const experimentMap = new Map(); + if (isLoadingSettings) { - return selectedMap; + return experimentMap; } - const selectedIdSet = new Set(allSelectedExperimentIds); + experiments.forEach((e, index) => { - Loadable.forEach(e, ({ experiment }) => { - if (selectedIdSet.has(experiment.id)) { - selectedMap.set(experiment.id, { experiment, index }); - } + Loadable.forEach(e, (experiment) => { + experimentMap.set(experiment.experiment.id, { experiment, index }); }); }); - return selectedMap; - }, [isLoadingSettings, allSelectedExperimentIds, experiments]); + return experimentMap; + }, [experiments, isLoadingSettings]); - const selection = useMemo(() => { - let rows = CompactSelection.empty(); - loadedSelectedExperimentIds.forEach((info) => { - rows = rows.add(info.index); - }); - return { - columns: CompactSelection.empty(), - rows, - }; - }, [loadedSelectedExperimentIds]); - - const colorMap = useGlasbey([...loadedSelectedExperimentIds.keys()]); + const colorMap = useGlasbey([...loadedExperimentIdMap.keys()]); const experimentFilters = useMemo(() => { const filters: V1BulkExperimentFilters = { @@ -399,71 +385,6 @@ const Searches: React.FC = ({ project }) => { }; }, [canceler, stopPolling]); - const rowRangeToIds = useCallback( - (range: [number, number]) => { - const slice = experiments.slice(range[0], range[1]); - return Loadable.filterNotLoaded(slice).map(({ experiment }) => experiment.id); - }, - [experiments], - ); - - const handleSelectionChange: HandleSelectionChangeType = useCallback( - (selectionType: SelectionType | RangelessSelectionType, range?: [number, number]) => { - let newSettings: SelectionState = { ...settings.selection }; - - switch (selectionType) { - case 'add': - if (!range) return; - if (newSettings.type === 'ALL_EXCEPT') { - const excludedSet = new Set(newSettings.exclusions); - rowRangeToIds(range).forEach((id) => excludedSet.delete(id)); - newSettings.exclusions = Array.from(excludedSet); - } else { - const includedSet = new Set(newSettings.selections); - rowRangeToIds(range).forEach((id) => includedSet.add(id)); - newSettings.selections = Array.from(includedSet); - } - - break; - case 'add-all': - newSettings = { - exclusions: [], - type: 'ALL_EXCEPT' as const, - }; - - break; - case 'remove': - if (!range) return; - if (newSettings.type === 'ALL_EXCEPT') { - const excludedSet = new Set(newSettings.exclusions); - rowRangeToIds(range).forEach((id) => excludedSet.add(id)); - newSettings.exclusions = Array.from(excludedSet); - } else { - const includedSet = new Set(newSettings.selections); - rowRangeToIds(range).forEach((id) => includedSet.delete(id)); - newSettings.selections = Array.from(includedSet); - } - - break; - case 'remove-all': - newSettings = DEFAULT_SELECTION; - - break; - case 'set': - if (!range) return; - newSettings = { - ...DEFAULT_SELECTION, - selections: Array.from(rowRangeToIds(range)), - }; - - break; - } - - updateSettings({ selection: newSettings }); - }, - [rowRangeToIds, settings.selection, updateSettings], - ); - const handleActionComplete = useCallback(async () => { /** * Deselect selected rows since their states may have changed where they @@ -639,7 +560,7 @@ const Searches: React.FC = ({ project }) => { const gridColumns = [...STATIC_COLUMNS, ...columnsIfLoaded] .map((columnName) => { if (columnName === MULTISELECT) { - return (columnDefs[columnName] = defaultSelectionColumn(selection.rows, false)); + return (columnDefs[columnName] = defaultSelectionColumn(dataGridSelection.rows, false)); } if (!Loadable.isLoaded(projectColumnsMap)) { @@ -712,39 +633,34 @@ const Searches: React.FC = ({ project }) => { columnsIfLoaded, appTheme, isDarkMode, - selection.rows, + dataGridSelection.rows, users, ]); + const handleActualSelectAll = useCallback(() => { + handleSelectionChange?.('add-all'); + }, [handleSelectionChange]); + + const handleClearSelect = useCallback(() => { + handleSelectionChange?.('remove-all'); + }, [handleSelectionChange]); + + const handleHeaderClick = useCallback( + (columnId: string): void => { + if (columnId === MULTISELECT) { + if (isRangeSelected([0, settings.pageLimit])) { + handleSelectionChange?.('remove', [0, settings.pageLimit]); + } else { + handleSelectionChange?.('add', [0, settings.pageLimit]); + } + } + }, + [handleSelectionChange, isRangeSelected, settings.pageLimit], + ); + const getHeaderMenuItems = (columnId: string, colIdx: number): MenuItem[] => { if (columnId === MULTISELECT) { - const items: MenuItem[] = [ - settings.selection.type === 'ALL_EXCEPT' || settings.selection.selections.length > 0 - ? { - key: 'select-none', - label: 'Clear selected', - onClick: () => { - handleSelectionChange?.('remove-all'); - }, - } - : null, - ...[5, 10, 25].map((n) => ({ - key: `select-${n}`, - label: `Select first ${n}`, - onClick: () => { - handleSelectionChange?.('set', [0, n]); - dataGridRef.current?.scrollToTop(); - }, - })), - { - key: 'select-all', - label: 'Select all', - onClick: () => { - handleSelectionChange?.('add', [0, settings.pageLimit]); - }, - }, - ]; - return items; + return []; } const column = Loadable.getOrElse([], projectColumns).find((c) => c.column === columnId); if (!column) { @@ -875,14 +791,20 @@ const Searches: React.FC = ({ project }) => { isOpenFilter={isOpenFilter} labelPlural="searches" labelSingular="search" + pageSize={settings.pageLimit} project={project} projectColumns={projectColumns} rowHeight={globalSettings.rowHeight} selectedExperimentIds={allSelectedExperimentIds} + selection={settings.selection} + selectionSize={selectionSize} sorts={sorts} + tableFilterString={filtersString} total={total} onActionComplete={handleActionComplete} onActionSuccess={handleActionSuccess} + onActualSelectAll={handleActualSelectAll} + onClearSelect={handleClearSelect} onIsOpenFilterChange={handleIsOpenFilterChange} onRowHeightChange={handleRowHeightChange} onSortChange={handleSortChange} @@ -938,12 +860,13 @@ const Searches: React.FC = ({ project }) => { ); }} rowHeight={rowHeightMap[globalSettings.rowHeight as RowHeight]} - selection={selection} + selection={dataGridSelection} sorts={sorts} staticColumns={STATIC_COLUMNS} onColumnResize={handleColumnWidthChange} onColumnsOrderChange={handleColumnsOrderChange} onContextMenuComplete={handleContextMenuComplete} + onHeaderClicked={handleHeaderClick} onPinnedColumnsCountChange={handlePinnedColumnsCountChange} onSelectionChange={handleSelectionChange} /> diff --git a/webui/react/src/components/TableActionBar.tsx b/webui/react/src/components/TableActionBar.tsx index 1196ff9ca5c..dff835730fd 100644 --- a/webui/react/src/components/TableActionBar.tsx +++ b/webui/react/src/components/TableActionBar.tsx @@ -15,27 +15,28 @@ import BatchActionConfirmModalComponent from 'components/BatchActionConfirmModal import ColumnPickerMenu from 'components/ColumnPickerMenu'; import ExperimentMoveModalComponent from 'components/ExperimentMoveModal'; import ExperimentRetainLogsModalComponent from 'components/ExperimentRetainLogsModal'; -import ExperimentTensorBoardModal from 'components/ExperimentTensorBoardModal'; import { FilterFormStore } from 'components/FilterForm/components/FilterFormStore'; import TableFilter from 'components/FilterForm/TableFilter'; import MultiSortMenu from 'components/MultiSortMenu'; import { OptionsMenu, RowHeight } from 'components/OptionsMenu'; import { defaultProjectSettings } from 'components/Searches/Searches.settings'; +import SearchTensorBoardModal from 'components/SearchTensorBoardModal'; import useMobile from 'hooks/useMobile'; import usePermissions from 'hooks/usePermissions'; import { defaultExperimentColumns } from 'pages/F_ExpList/expListColumns'; import { - activateExperiments, - archiveExperiments, - cancelExperiments, - deleteExperiments, + archiveSearches, + cancelSearches, + deleteSearches, getExperiments, - killExperiments, - openOrCreateTensorBoard, - pauseExperiments, - unarchiveExperiments, + killSearches, + openOrCreateTensorBoardSearches, + pauseSearches, + resumeSearches, + unarchiveSearches, } from 'services/api'; import { V1LocationType } from 'services/api-ts-sdk'; +import { SearchBulkActionParams } from 'services/types'; import { BulkActionResult, BulkExperimentItem, @@ -43,16 +44,19 @@ import { Project, ProjectColumn, ProjectExperiment, + SelectionType, } from 'types'; import handleError, { ErrorLevel } from 'utils/error'; import { canActionExperiment, getActionsForExperimentsUnion, + getIdsFilter, getProjectExperimentForExperimentItem, } from 'utils/experiment'; import { capitalizeWord } from 'utils/string'; import { openCommandResponse } from 'utils/wait'; +import { FilterFormSet } from './FilterForm/components/type'; import LoadableCount from './LoadableCount'; import css from './TableActionBar.module.scss'; @@ -93,6 +97,8 @@ interface Props { isOpenFilter: boolean; onActionComplete?: () => Promise; onActionSuccess?: (action: BatchAction, successfulIds: number[]) => void; + onActualSelectAll?: () => void; + onClearSelect?: () => void; onComparisonViewToggle?: () => void; onHeatmapToggle?: (heatmapOn: boolean) => void; onIsOpenFilterChange?: (value: boolean) => void; @@ -100,10 +106,13 @@ interface Props { onSortChange?: (sorts: Sort[]) => void; onVisibleColumnChange?: (newColumns: string[], pinnedCount?: number) => void; onHeatmapSelectionRemove?: (id: string) => void; + pageSize?: number; project: Project; projectColumns: Loadable; rowHeight: RowHeight; selectedExperimentIds: number[]; + selection: SelectionType; + selectionSize: number; sorts: Sort[]; pinnedColumnsCount?: number; total: Loadable; @@ -113,17 +122,21 @@ interface Props { bannedFilterColumns?: Set; bannedSortColumns?: Set; entityCopy?: string; + tableFilterString: string; } const TableActionBar: React.FC = ({ compareViewOn, formStore, + tableFilterString, heatmapBtnVisible, heatmapOn, initialVisibleColumns, isOpenFilter, onActionComplete, onActionSuccess, + onActualSelectAll, + onClearSelect, onComparisonViewToggle, onHeatmapToggle, onIsOpenFilterChange, @@ -131,6 +144,7 @@ const TableActionBar: React.FC = ({ onSortChange, onHeatmapSelectionRemove, onVisibleColumnChange, + pageSize, project, projectColumns, rowHeight, @@ -144,14 +158,16 @@ const TableActionBar: React.FC = ({ bannedFilterColumns, bannedSortColumns, entityCopy, + selectionSize, + selection, }) => { const permissions = usePermissions(); const [batchAction, setBatchAction] = useState(); const BatchActionConfirmModal = useModal(BatchActionConfirmModalComponent); const ExperimentMoveModal = useModal(ExperimentMoveModalComponent); const ExperimentRetainLogsModal = useModal(ExperimentRetainLogsModalComponent); - const { Component: ExperimentTensorBoardModalComponent, open: openExperimentTensorBoardModal } = - useModal(ExperimentTensorBoardModal); + const { Component: SearchTensorBoardModalComponent, open: openSearchTensorBoardModal } = + useModal(SearchTensorBoardModal); const isMobile = useMobile(); const { openToast } = useToast(); @@ -201,32 +217,50 @@ const TableActionBar: React.FC = ({ ); const availableBatchActions = useMemo(() => { - const experiments = selectedExperimentIds.map((id) => experimentMap[id]) ?? []; - return getActionsForExperimentsUnion(experiments, [...batchActions], permissions); - // Spreading batchActions is so TypeScript doesn't complain that it's readonly. - }, [selectedExperimentIds, experimentMap, permissions]); + switch (selection.type) { + case 'ONLY_IN': { + const experiments = selection.selections.map((id) => experimentMap[id]) ?? []; + return getActionsForExperimentsUnion(experiments, [...batchActions], permissions); // Spreading batchActions is so TypeScript doesn't complain that it's readonly. + } + case 'ALL_EXCEPT': + return batchActions; + } + }, [selection, permissions, experimentMap]); const sendBatchActions = useCallback( async (action: BatchAction): Promise => { - const validExperimentIds = selectedExperiments - .filter((exp) => !exp.unmanaged && canActionExperiment(action, exp)) - .map((exp) => exp.id); - const params = { - experimentIds: validExperimentIds, - projectId: project.id, - }; + const params: SearchBulkActionParams = { projectId: project.id }; + switch (selection.type) { + case 'ONLY_IN': { + const validSearchIds = selectedExperiments + .filter((exp) => !exp.unmanaged && canActionExperiment(action, exp)) + .map((exp) => exp.id); + params.searchIds = validSearchIds; + break; + } + case 'ALL_EXCEPT': { + const filterFormSet = JSON.parse(tableFilterString) as FilterFormSet; + params.filter = JSON.stringify(getIdsFilter(filterFormSet, selection)); + break; + } + } + switch (action) { case ExperimentAction.OpenTensorBoard: { - if (validExperimentIds.length !== selectedExperiments.length) { - // if unmanaged experiments are selected, open experimentTensorBoardModal - openExperimentTensorBoardModal(); - } else { + if ( + params.searchIds === undefined || + params.searchIds.length === selectedExperiments.length + ) { openCommandResponse( - await openOrCreateTensorBoard({ - experimentIds: params.experimentIds, + await openOrCreateTensorBoardSearches({ + filter: params.filter, + searchIds: params.searchIds, workspaceId: project?.workspaceId, }), ); + } else { + // if unmanaged experiments are selected, open searchTensorBoardModal + openSearchTensorBoardModal(); } return; } @@ -235,27 +269,30 @@ const TableActionBar: React.FC = ({ case ExperimentAction.RetainLogs: return ExperimentRetainLogsModal.open(); case ExperimentAction.Activate: - return await activateExperiments(params); + return await resumeSearches(params); case ExperimentAction.Archive: - return await archiveExperiments(params); + return await archiveSearches(params); case ExperimentAction.Cancel: - return await cancelExperiments(params); + return await cancelSearches(params); case ExperimentAction.Kill: - return await killExperiments(params); + return await killSearches(params); case ExperimentAction.Pause: - return await pauseExperiments(params); + return await pauseSearches(params); case ExperimentAction.Unarchive: - return await unarchiveExperiments(params); + return await unarchiveSearches(params); case ExperimentAction.Delete: - return await deleteExperiments(params); + return await deleteSearches(params); } }, [ + project.id, + project?.workspaceId, + selection, selectedExperiments, + tableFilterString, ExperimentMoveModal, ExperimentRetainLogsModal, - openExperimentTensorBoardModal, - project, + openSearchTensorBoardModal, ], ); @@ -312,8 +349,7 @@ const TableActionBar: React.FC = ({ closeable: true, description: `${action} succeeded for ${numSuccesses} out of ${ numFailures + numSuccesses - } eligible - ${labelPlural.toLowerCase()}`, + } ${labelPlural.toLowerCase()}`, severity: 'Warning', title: `Partial ${action} Failure`, }); @@ -376,8 +412,6 @@ const TableActionBar: React.FC = ({ }, [] as MenuItem[]); }, [availableBatchActions]); - const handleAction = useCallback((key: string) => handleBatchAction(key), [handleBatchAction]); - return (
@@ -413,8 +447,8 @@ const TableActionBar: React.FC = ({ onVisibleColumnChange={onVisibleColumnChange} /> - {selectedExperimentIds.length > 0 && ( - + {selectionSize > 0 && ( + @@ -423,8 +457,11 @@ const TableActionBar: React.FC = ({ @@ -460,13 +497,11 @@ const TableActionBar: React.FC = ({ /> )} - canActionExperiment(ExperimentAction.Move, experimentMap[id]) && - permissions.canMoveExperiment({ experiment: experimentMap[id] }), - )} + selection={selection} + selectionSize={selectionSize} sourceProjectId={project.id} sourceWorkspaceId={project.workspaceId} + tableFilters={tableFilterString} onSubmit={handleSubmitMove} /> = ({ projectId={project.id} onSubmit={handleSubmitRetainLogs} /> -
diff --git a/webui/react/src/e2e/models/common/hew/DataGrid.ts b/webui/react/src/e2e/models/common/hew/DataGrid.ts index 98722cc9be7..8b2484f352f 100644 --- a/webui/react/src/e2e/models/common/hew/DataGrid.ts +++ b/webui/react/src/e2e/models/common/hew/DataGrid.ts @@ -5,7 +5,6 @@ import { } from 'playwright-page-model-base/BaseComponent'; import { expect } from 'e2e/fixtures/global-fixtures'; -import { DropdownMenu } from 'e2e/models/common/hew/Dropdown'; import { printMap } from 'e2e/utils/debug'; class IndexNotFoundError extends Error {} @@ -396,14 +395,6 @@ export class HeadRow>> extends NamedCompone parent: this, selector: 'th', }); - readonly selectDropdown = new HeaderDropdown({ - clickThisComponentToOpen: new BaseComponent({ - parent: this, - selector: `[${DataGrid.columnIndexAttribute}="1"]`, - }), - openMethod: this.clickSelectDropdown.bind(this), - root: this.root, - }); #columnDefs = new Map(); @@ -473,18 +464,8 @@ export class HeadRow>> extends NamedCompone /** * Clicks the head row's select button */ - async clickSelectDropdown(): Promise { + async clickSelectHeader(): Promise { // magic numbers for the select button await this.parentTable.pwLocator.click({ position: { x: 5, y: 5 } }); } } - -/** - * Represents the HeaderDropdown from the DataGrid component - */ -class HeaderDropdown extends DropdownMenu { - readonly select5 = this.menuItem('select-5'); - readonly select10 = this.menuItem('select-10'); - readonly select25 = this.menuItem('select-25'); - readonly selectAll = this.menuItem('select-all'); -} diff --git a/webui/react/src/e2e/models/components/TableActionBar.ts b/webui/react/src/e2e/models/components/TableActionBar.ts index 432d6a3df22..17acec62672 100644 --- a/webui/react/src/e2e/models/components/TableActionBar.ts +++ b/webui/react/src/e2e/models/components/TableActionBar.ts @@ -25,6 +25,7 @@ export class TableActionBar extends NamedComponent { count = new BaseComponent({ parent: this, selector: '[data-test="count"]' }); heatmapToggle = new BaseComponent({ parent: this, selector: '[data-test="heatmapToggle"]' }); compare = new BaseComponent({ parent: this, selector: '[data-test="compare"]' }); + clearSelection = new BaseComponent({ parent: this, selector: '[data-test="clear-selection"]' }); // TODO a bunch of modals } diff --git a/webui/react/src/e2e/models/pages/ProjectDetails.ts b/webui/react/src/e2e/models/pages/ProjectDetails.ts index 9c2aa63530d..8aa2d812552 100644 --- a/webui/react/src/e2e/models/pages/ProjectDetails.ts +++ b/webui/react/src/e2e/models/pages/ProjectDetails.ts @@ -5,7 +5,7 @@ import { F_ExperimentList } from 'e2e/models/components/F_ExperimentList'; import { PageComponent } from 'e2e/models/components/Page'; /** - * Represents the SignIn page from src/pages/ProjectDetails.tsx + * Represents the ProjectDetails page from src/pages/ProjectDetails.tsx */ export class ProjectDetails extends DeterminedPage { readonly title = /Uncategorized Experiments|Project Details/; @@ -35,6 +35,8 @@ export class ProjectDetails extends DeterminedPage { return Number(matches[1]); } + // async getRowsSelected(): Promise<{ selected: number; total?: number }> {} + readonly pageComponent = new PageComponent({ parent: this }); readonly dynamicTabs = new DynamicTabs({ parent: this.pageComponent }); readonly runsTab = this.dynamicTabs.pivot.tab('runs'); diff --git a/webui/react/src/e2e/tests/experimentList.spec.ts b/webui/react/src/e2e/tests/experimentList.spec.ts index e2bcc21a50b..cd404d82e70 100644 --- a/webui/react/src/e2e/tests/experimentList.spec.ts +++ b/webui/react/src/e2e/tests/experimentList.spec.ts @@ -13,7 +13,7 @@ test.describe('Experiment List', () => { const getCount = async () => { const count = await projectDetailsPage.f_experimentList.tableActionBar.count.pwLocator.textContent(); - if (count === null) throw new Error('Count is null'); + if (count === null) return 0; return parseInt(count); }; @@ -56,11 +56,14 @@ test.describe('Experiment List', () => { timeout: 10_000, }); await test.step('Deselect', async () => { - try { - await grid.headRow.selectDropdown.menuItem('select-none').select({ timeout: 1_000 }); - } catch (e) { - // close the dropdown by clicking elsewhere - await projectDetailsPage.f_experimentList.tableActionBar.count.pwLocator.click(); + const count = await getCount(); + if (count !== 0) { + await grid.headRow.clickSelectHeader(); + const isClearSelectionVisible = + await projectDetailsPage.f_experimentList.tableActionBar.clearSelection.pwLocator.isVisible(); + if (isClearSelectionVisible) { + await projectDetailsPage.f_experimentList.tableActionBar.clearSelection.pwLocator.click(); + } } }); await test.step('Reset Columns', async () => { @@ -296,11 +299,6 @@ test.describe('Experiment List', () => { await test.step('Read Cell Value', async () => { await expect.soft((await row.getCellByColumnName('ID')).pwLocator).toHaveText(/\d+/); }); - await test.step('Select 5', async () => { - await ( - await projectDetailsPage.f_experimentList.dataGrid.headRow.selectDropdown.open() - ).select5.pwLocator.click(); - }); await test.step('Experiment Overview Navigation', async () => { await projectDetailsPage.f_experimentList.dataGrid.scrollLeft(); const textContent = await (await row.getCellByColumnName('ID')).pwLocator.textContent(); diff --git a/webui/react/src/hooks/useSelection.ts b/webui/react/src/hooks/useSelection.ts new file mode 100644 index 00000000000..5fb6f7e4c5a --- /dev/null +++ b/webui/react/src/hooks/useSelection.ts @@ -0,0 +1,201 @@ +import { CompactSelection, GridSelection } from '@glideapps/glide-data-grid'; +import { + HandleSelectionChangeType, + RangelessSelectionType, + SelectionType, +} from 'hew/DataGrid/DataGrid'; +import { Loadable } from 'hew/utils/loadable'; +import * as t from 'io-ts'; +import { useCallback, useMemo } from 'react'; + +import { RegularSelectionType, SelectionType as SelectionState } from 'types'; + +export const DEFAULT_SELECTION: t.TypeOf = { + selections: [], + type: 'ONLY_IN', +}; + +interface HasId { + id: number; +} + +interface SelectionConfig { + records: Loadable[]; + selection: SelectionState; + total: Loadable; + updateSettings: (p: Record) => void; +} + +interface UseSelectionReturn { + selectionSize: number; + dataGridSelection: GridSelection; + handleSelectionChange: HandleSelectionChangeType; + rowRangeToIds: (range: [number, number]) => number[]; + loadedSelectedRecords: T[]; + loadedSelectedRecordIds: number[]; + isRangeSelected: (range: [number, number]) => boolean; +} + +const useSelection = (config: SelectionConfig): UseSelectionReturn => { + const loadedRecordIdMap = useMemo(() => { + const recordMap = new Map(); + + config.records.forEach((r, index) => { + Loadable.forEach(r, (record) => { + recordMap.set(record.id, { index, record }); + }); + }); + return recordMap; + }, [config.records]); + + const selectedRecordIdSet = useMemo(() => { + switch (config.selection.type) { + case 'ONLY_IN': + return new Set(config.selection.selections); + case 'ALL_EXCEPT': { + const excludedSet = new Set(config.selection.exclusions); + return new Set( + Loadable.filterNotLoaded(config.records, (record) => !excludedSet.has(record.id)).map( + (record) => record.id, + ), + ); + } + } + }, [config.records, config.selection]); + + const dataGridSelection = useMemo(() => { + let rows = CompactSelection.empty(); + switch (config.selection.type) { + case 'ONLY_IN': + config.selection.selections.forEach((id) => { + const incIndex = loadedRecordIdMap.get(id)?.index; + if (incIndex !== undefined) { + rows = rows.add(incIndex); + } + }); + break; + case 'ALL_EXCEPT': + rows = rows.add([0, config.total.getOrElse(1) - 1]); + config.selection.exclusions.forEach((exc) => { + const excIndex = loadedRecordIdMap.get(exc)?.index; + if (excIndex !== undefined) { + rows = rows.remove(excIndex); + } + }); + break; + } + return { + columns: CompactSelection.empty(), + rows, + }; + }, [loadedRecordIdMap, config.selection, config.total]); + + const loadedSelectedRecords: T[] = useMemo(() => { + return Loadable.filterNotLoaded(config.records, (record) => selectedRecordIdSet.has(record.id)); + }, [config.records, selectedRecordIdSet]); + + const loadedSelectedRecordIds: number[] = useMemo(() => { + return loadedSelectedRecords.map((record) => record.id); + }, [loadedSelectedRecords]); + + const selectionSize = useMemo(() => { + switch (config.selection.type) { + case 'ONLY_IN': + return config.selection.selections.length; + case 'ALL_EXCEPT': + return config.total.getOrElse(0) - config.selection.exclusions.length; + } + }, [config.selection, config.total]); + + const rowRangeToIds = useCallback( + (range: [number, number]) => { + const slice = config.records.slice(range[0], range[1]); + return Loadable.filterNotLoaded(slice).map((run) => run.id); + }, + [config.records], + ); + + const handleSelectionChange: HandleSelectionChangeType = useCallback( + (selectionType: SelectionType | RangelessSelectionType, range?: [number, number]) => { + let newSettings: SelectionState = { ...config.selection }; + + switch (selectionType) { + case 'add': + if (!range) return; + if (newSettings.type === 'ALL_EXCEPT') { + const excludedSet = new Set(newSettings.exclusions); + rowRangeToIds(range).forEach((id) => excludedSet.delete(id)); + newSettings.exclusions = Array.from(excludedSet); + } else { + const includedSet = new Set(newSettings.selections); + rowRangeToIds(range).forEach((id) => includedSet.add(id)); + newSettings.selections = Array.from(includedSet); + } + + break; + case 'add-all': + newSettings = { + exclusions: [], + type: 'ALL_EXCEPT', + }; + + break; + case 'remove': + if (!range) return; + if (newSettings.type === 'ALL_EXCEPT') { + const excludedSet = new Set(newSettings.exclusions); + rowRangeToIds(range).forEach((id) => excludedSet.add(id)); + newSettings.exclusions = Array.from(excludedSet); + } else { + const includedSet = new Set(newSettings.selections); + rowRangeToIds(range).forEach((id) => includedSet.delete(id)); + newSettings.selections = Array.from(includedSet); + } + + break; + case 'remove-all': + newSettings = DEFAULT_SELECTION; + + break; + case 'set': + if (!range) return; + newSettings = { + ...DEFAULT_SELECTION, + selections: Array.from(rowRangeToIds(range)), + }; + + break; + } + config.updateSettings({ selection: newSettings }); + }, + [config, rowRangeToIds], + ); + + const isRangeSelected = useCallback( + (range: [number, number]): boolean => { + switch (config.selection.type) { + case 'ONLY_IN': { + const includedSet = new Set(config.selection.selections); + return rowRangeToIds(range).every((id) => includedSet.has(id)); + } + case 'ALL_EXCEPT': { + const excludedSet = new Set(config.selection.exclusions); + return rowRangeToIds(range).every((id) => !excludedSet.has(id)); + } + } + }, + [rowRangeToIds, config.selection], + ); + + return { + dataGridSelection, + handleSelectionChange, + isRangeSelected, + loadedSelectedRecordIds, + loadedSelectedRecords, + rowRangeToIds, + selectionSize, + }; +}; + +export default useSelection; diff --git a/webui/react/src/pages/ExperimentDetails/ExperimentDetailsHeader.tsx b/webui/react/src/pages/ExperimentDetails/ExperimentDetailsHeader.tsx index b6b52f8b1ed..4e6ff2c2f0b 100644 --- a/webui/react/src/pages/ExperimentDetails/ExperimentDetailsHeader.tsx +++ b/webui/react/src/pages/ExperimentDetails/ExperimentDetailsHeader.tsx @@ -785,6 +785,7 @@ const ExperimentDetailsHeader: React.FC = ({ = ({ project }) => { /> = ({ project }) => { const isMobile = useMobile(); const { openToast } = useToast(); + const { + ui: { theme: appTheme }, + isDarkMode, + } = useUI(); + + const { + selectionSize, + dataGridSelection, + handleSelectionChange, + isRangeSelected, + loadedSelectedRecordIds: loadedSelectedExperimentIds, + } = useSelection({ + records: experiments.map((loadable) => loadable.map((exp) => exp.experiment)), + selection: settings.selection, + total, + updateSettings, + }); + const handlePinnedColumnsCountChange = useCallback( (newCount: number) => updateSettings({ pinnedColumnsCount: newCount }), [updateSettings], @@ -248,38 +255,7 @@ const F_ExperimentList: React.FC = ({ project }) => { const [error] = useState(false); const [canceler] = useState(new AbortController()); - const allSelectedExperimentIds = useMemo(() => { - return settings.selection.type === 'ONLY_IN' ? settings.selection.selections : []; - }, [settings.selection]); - - const loadedSelectedExperimentIds = useMemo(() => { - const selectedMap = new Map(); - if (isLoadingSettings) { - return selectedMap; - } - const selectedIdSet = new Set(allSelectedExperimentIds); - experiments.forEach((e, index) => { - Loadable.forEach(e, ({ experiment }) => { - if (selectedIdSet.has(experiment.id)) { - selectedMap.set(experiment.id, { experiment, index }); - } - }); - }); - return selectedMap; - }, [isLoadingSettings, allSelectedExperimentIds, experiments]); - - const selection = useMemo(() => { - let rows = CompactSelection.empty(); - loadedSelectedExperimentIds.forEach((info) => { - rows = rows.add(info.index); - }); - return { - columns: CompactSelection.empty(), - rows, - }; - }, [loadedSelectedExperimentIds]); - - const colorMap = useGlasbey([...loadedSelectedExperimentIds.keys()]); + const colorMap = useGlasbey(loadedSelectedExperimentIds); const { width: containerWidth } = useResize(contentRef); const experimentFilters = useMemo(() => { @@ -437,71 +413,6 @@ const F_ExperimentList: React.FC = ({ project }) => { }; }, [canceler, stopPolling]); - const rowRangeToIds = useCallback( - (range: [number, number]) => { - const slice = experiments.slice(range[0], range[1]); - return Loadable.filterNotLoaded(slice).map(({ experiment }) => experiment.id); - }, - [experiments], - ); - - const handleSelectionChange: HandleSelectionChangeType = useCallback( - (selectionType: SelectionType | RangelessSelectionType, range?: [number, number]) => { - let newSettings: SelectionState = { ...settings.selection }; - - switch (selectionType) { - case 'add': - if (!range) return; - if (newSettings.type === 'ALL_EXCEPT') { - const excludedSet = new Set(newSettings.exclusions); - rowRangeToIds(range).forEach((id) => excludedSet.delete(id)); - newSettings.exclusions = Array.from(excludedSet); - } else { - const includedSet = new Set(newSettings.selections); - rowRangeToIds(range).forEach((id) => includedSet.add(id)); - newSettings.selections = Array.from(includedSet); - } - - break; - case 'add-all': - newSettings = { - exclusions: [], - type: 'ALL_EXCEPT' as const, - }; - - break; - case 'remove': - if (!range) return; - if (newSettings.type === 'ALL_EXCEPT') { - const excludedSet = new Set(newSettings.exclusions); - rowRangeToIds(range).forEach((id) => excludedSet.add(id)); - newSettings.exclusions = Array.from(excludedSet); - } else { - const includedSet = new Set(newSettings.selections); - rowRangeToIds(range).forEach((id) => includedSet.delete(id)); - newSettings.selections = Array.from(includedSet); - } - - break; - case 'remove-all': - newSettings = DEFAULT_SELECTION; - - break; - case 'set': - if (!range) return; - newSettings = { - ...DEFAULT_SELECTION, - selections: Array.from(rowRangeToIds(range)), - }; - - break; - } - - updateSettings({ selection: newSettings }); - }, - [rowRangeToIds, settings.selection, updateSettings], - ); - const handleActionComplete = useCallback(async () => { /** * Deselect selected rows since their states may have changed where they @@ -576,6 +487,14 @@ const F_ExperimentList: React.FC = ({ project }) => { [handleSelectionChange, openToast], ); + const handleActualSelectAll = useCallback(() => { + handleSelectionChange?.('add-all'); + }, [handleSelectionChange]); + + const handleClearSelect = useCallback(() => { + handleSelectionChange?.('remove-all'); + }, [handleSelectionChange]); + const handleContextMenuComplete = useCallback( (action: ExperimentAction, id: number, data?: Partial) => handleActionSuccess(action, [id], data), @@ -734,11 +653,6 @@ const F_ExperimentList: React.FC = ({ project }) => { ); }, [isMobile, settings.compare, settings.pinnedColumnsCount]); - const { - ui: { theme: appTheme }, - isDarkMode, - } = useUI(); - const users = useObservable(usersStore.getUsers()); const columns: ColumnDef[] = useMemo(() => { @@ -761,7 +675,7 @@ const F_ExperimentList: React.FC = ({ project }) => { ) .map((columnName) => { if (columnName === MULTISELECT) { - return (columnDefs[columnName] = defaultSelectionColumn(selection.rows, false)); + return (columnDefs[columnName] = defaultSelectionColumn(dataGridSelection.rows, false)); } if (!Loadable.isLoaded(projectColumnsMap)) { @@ -892,49 +806,36 @@ const F_ExperimentList: React.FC = ({ project }) => { .flatMap((col) => (col ? [col] : [])); return gridColumns; }, [ - settings.compare, - settings.pinnedColumnsCount, projectColumns, + appTheme, settings.columnWidths, - settings.heatmapSkipped, - projectHeatmap, + settings.compare, + settings.pinnedColumnsCount, settings.heatmapOn, - columnsIfLoaded, - appTheme, + settings.heatmapSkipped, isDarkMode, - selection.rows, users, + columnsIfLoaded, + dataGridSelection.rows, + projectHeatmap, ]); + const handleHeaderClick = useCallback( + (columnId: string): void => { + if (columnId === MULTISELECT) { + if (isRangeSelected([0, settings.pageLimit])) { + handleSelectionChange?.('remove', [0, settings.pageLimit]); + } else { + handleSelectionChange?.('add', [0, settings.pageLimit]); + } + } + }, + [handleSelectionChange, isRangeSelected, settings.pageLimit], + ); + const getHeaderMenuItems = (columnId: string, colIdx: number): MenuItem[] => { if (columnId === MULTISELECT) { - const items: MenuItem[] = [ - settings.selection.type === 'ALL_EXCEPT' || settings.selection.selections.length > 0 - ? { - key: 'select-none', - label: 'Clear selected', - onClick: () => { - handleSelectionChange?.('remove-all'); - }, - } - : null, - ...[5, 10, 25].map((n) => ({ - key: `select-${n}`, - label: `Select first ${n}`, - onClick: () => { - handleSelectionChange?.('set', [0, n]); - dataGridRef.current?.scrollToTop(); - }, - })), - { - key: 'select-all', - label: 'Select all', - onClick: () => { - handleSelectionChange?.('add', [0, settings.pageLimit]); - }, - }, - ]; - return items; + return []; } const column = Loadable.getOrElse([], projectColumns).find((c) => c.column === columnId); if (!column) { @@ -1096,11 +997,16 @@ const F_ExperimentList: React.FC = ({ project }) => { project={project} projectColumns={projectColumns} rowHeight={globalSettings.rowHeight} - selectedExperimentIds={allSelectedExperimentIds} + selectedExperimentIds={loadedSelectedExperimentIds} + selection={settings.selection} + selectionSize={selectionSize} sorts={sorts} + tableFilterString={filtersString} total={total} onActionComplete={handleActionComplete} onActionSuccess={handleActionSuccess} + onActualSelectAll={handleActualSelectAll} + onClearSelect={handleClearSelect} onComparisonViewToggle={handleToggleComparisonView} onHeatmapSelectionRemove={(id) => { const newSelection = settings.heatmapSkipped.filter((s) => s !== id); @@ -1130,6 +1036,7 @@ const F_ExperimentList: React.FC = ({ project }) => { initialWidth={comparisonViewTableWidth} open={settings.compare} projectId={project.id} + tableFilters={filtersString} onWidthChange={handleCompareWidthChange}> columns={columns} @@ -1165,12 +1072,13 @@ const F_ExperimentList: React.FC = ({ project }) => { ); }} rowHeight={rowHeightMap[globalSettings.rowHeight]} - selection={selection} + selection={dataGridSelection} sorts={sorts} staticColumns={STATIC_COLUMNS} onColumnResize={handleColumnWidthChange} onColumnsOrderChange={handleColumnsOrderChange} onContextMenuComplete={handleContextMenuComplete} + onHeaderClicked={handleHeaderClick} onPinnedColumnsCountChange={handlePinnedColumnsCountChange} onSelectionChange={handleSelectionChange} /> diff --git a/webui/react/src/pages/FlatRuns/FlatRunActionButton.test.tsx b/webui/react/src/pages/FlatRuns/FlatRunActionButton.test.tsx index f6f9bd4dced..de659e91200 100644 --- a/webui/react/src/pages/FlatRuns/FlatRunActionButton.test.tsx +++ b/webui/react/src/pages/FlatRuns/FlatRunActionButton.test.tsx @@ -25,9 +25,12 @@ const setup = (selectedFlatRuns: ReadonlyArray>) => { render( run.id), type: 'ONLY_IN' }} + selectionSize={selectedFlatRuns.length} workspaceId={1} onActionComplete={onActionComplete} onActionSuccess={onActionSuccess} diff --git a/webui/react/src/pages/FlatRuns/FlatRunActionButton.tsx b/webui/react/src/pages/FlatRuns/FlatRunActionButton.tsx index 9dc726499a8..6b8b8820f6c 100644 --- a/webui/react/src/pages/FlatRuns/FlatRunActionButton.tsx +++ b/webui/react/src/pages/FlatRuns/FlatRunActionButton.tsx @@ -9,6 +9,7 @@ import { useObservable } from 'micro-observables'; import { useCallback, useMemo, useState } from 'react'; import BatchActionConfirmModalComponent from 'components/BatchActionConfirmModal'; +import { FilterFormSetWithoutId } from 'components/FilterForm/components/type'; import Link from 'components/Link'; import usePermissions from 'hooks/usePermissions'; import FlatRunMoveModalComponent from 'pages/FlatRuns/FlatRunMoveModal'; @@ -21,10 +22,11 @@ import { resumeRuns, unarchiveRuns, } from 'services/api'; +import { RunBulkActionParams } from 'services/types'; import projectStore from 'stores/projects'; -import { BulkActionResult, ExperimentAction, FlatRun, Project } from 'types'; +import { BulkActionResult, ExperimentAction, FlatRun, Project, SelectionType } from 'types'; import handleError from 'utils/error'; -import { canActionFlatRun, getActionsForFlatRunsUnion } from 'utils/flatRun'; +import { canActionFlatRun, getActionsForFlatRunsUnion, getIdsFilter } from 'utils/flatRun'; import { capitalizeWord, pluralizer } from 'utils/string'; const BATCH_ACTIONS = [ @@ -52,18 +54,24 @@ const ACTION_ICONS: Record = { const LABEL_PLURAL = 'runs'; interface Props { + filter: string; isMobile: boolean; selectedRuns: ReadonlyArray>; projectId: number; workspaceId: number; onActionSuccess?: (action: BatchAction, successfulIds: number[]) => void; onActionComplete?: () => void | Promise; + selection: SelectionType; + selectionSize: number; } const FlatRunActionButton = ({ + filter, isMobile, selectedRuns, projectId, + selection, + selectionSize, workspaceId, onActionSuccess, onActionComplete, @@ -80,13 +88,21 @@ const FlatRunActionButton = ({ const sendBatchActions = useCallback( async (action: BatchAction): Promise => { - const validRunIds = selectedRuns - .filter((exp) => canActionFlatRun(action, exp)) - .map((run) => run.id); - const params = { - projectId, - runIds: validRunIds, - }; + const params: RunBulkActionParams = { projectId }; + switch (selection.type) { + case 'ONLY_IN': { + const validRunIds = selectedRuns + .filter((run) => canActionFlatRun(action, run)) + .map((run) => run.id); + params.runIds = validRunIds; + break; + } + case 'ALL_EXCEPT': { + const filterFormSet = JSON.parse(filter) as FilterFormSetWithoutId; + params.filter = JSON.stringify(getIdsFilter(filterFormSet, selection)); + break; + } + } switch (action) { case ExperimentAction.Move: flatRunMoveModalOpen(); @@ -105,7 +121,7 @@ const FlatRunActionButton = ({ return await resumeRuns(params); } }, - [flatRunMoveModalOpen, projectId, selectedRuns], + [flatRunMoveModalOpen, projectId, selectedRuns, selection, filter], ); const submitBatchAction = useCallback( @@ -139,8 +155,7 @@ const FlatRunActionButton = ({ } else { openToast({ closeable: true, - description: `${action} succeeded for ${numSuccesses} out of ${numFailures + numSuccesses} eligible - ${pluralizer(numFailures + numSuccesses, 'run')}`, + description: `${action} succeeded for ${numSuccesses} out of ${numFailures + numSuccesses} ${pluralizer(numFailures + numSuccesses, 'run')}`, severity: 'Warning', title: `Partial ${action} Failure`, }); @@ -175,8 +190,13 @@ const FlatRunActionButton = ({ ); const availableBatchActions = useMemo(() => { - return getActionsForFlatRunsUnion(selectedRuns, [...BATCH_ACTIONS], permissions); - }, [selectedRuns, permissions]); + switch (selection.type) { + case 'ONLY_IN': + return getActionsForFlatRunsUnion(selectedRuns, [...BATCH_ACTIONS], permissions); + case 'ALL_EXCEPT': + return BATCH_ACTIONS; + } + }, [selection.type, selectedRuns, permissions]); const editMenuItems = useMemo(() => { const groupedBatchActions = [BATCH_ACTIONS]; @@ -197,7 +217,7 @@ const FlatRunActionButton = ({ }, []); }, [availableBatchActions]); - const onSubmit = useCallback( + const onSubmitMove = useCallback( async (results: BulkActionResult, destinationProjectId: number) => { const numSuccesses = results?.successful.length ?? 0; const numFailures = results?.failed.length ?? 0; @@ -241,7 +261,7 @@ const FlatRunActionButton = ({ return ( <> - {selectedRuns.length > 0 && ( + {selectionSize > 0 && ( diff --git a/webui/react/src/pages/FlatRuns/FlatRunMoveModal.tsx b/webui/react/src/pages/FlatRuns/FlatRunMoveModal.tsx index 68f3544f87d..c60b305bed3 100644 --- a/webui/react/src/pages/FlatRuns/FlatRunMoveModal.tsx +++ b/webui/react/src/pages/FlatRuns/FlatRunMoveModal.tsx @@ -10,6 +10,8 @@ import { List } from 'immutable'; import { useObservable } from 'micro-observables'; import React, { Ref, useCallback, useEffect, useId, useRef } from 'react'; +import { INIT_FORMSET } from 'components/FilterForm/components/FilterFormStore'; +import { FilterFormSet } from 'components/FilterForm/components/type'; import RunFilterInterstitialModalComponent, { ControlledModalRef, } from 'components/RunFilterInterstitialModalComponent'; @@ -19,10 +21,12 @@ import RunMoveWarningModalComponent, { import usePermissions from 'hooks/usePermissions'; import { formStore } from 'pages/FlatRuns/FlatRuns'; import { moveRuns } from 'services/api'; +import { V1MoveRunsRequest } from 'services/api-ts-sdk'; import projectStore from 'stores/projects'; import workspaceStore from 'stores/workspaces'; -import { BulkActionResult, FlatRun, Project } from 'types'; +import { BulkActionResult, Project, SelectionType, XOR } from 'types'; import handleError from 'utils/error'; +import { getIdsFilter as getRunIdsFilter } from 'utils/flatRun'; import { pluralizer } from 'utils/string'; const FORM_ID = 'move-flat-run-form'; @@ -32,15 +36,21 @@ type FormInputs = { destinationWorkspaceId?: number; }; -interface Props { - flatRuns: Readonly[]; +interface BaseProps { + selectionSize: number; sourceProjectId: number; sourceWorkspaceId?: number; onSubmit?: (results: BulkActionResult, destinationProjectId: number) => void | Promise; } +type Props = BaseProps & + XOR<{ runIds: number[] }, { selection: SelectionType; tableFilters: string }>; + const FlatRunMoveModalComponent: React.FC = ({ - flatRuns, + runIds, + tableFilters, + selection, + selectionSize, sourceProjectId, sourceWorkspaceId, onSubmit, @@ -97,24 +107,38 @@ const FlatRunMoveModalComponent: React.FC = ({ return; } - const results = await moveRuns({ + const moveRunsArgs: V1MoveRunsRequest = { destinationProjectId: projId, - runIds: flatRuns.map((flatRun) => flatRun.id), sourceProjectId, - }); + }; + + if (tableFilters !== undefined) { + const filterFormSet = + selection.type === 'ALL_EXCEPT' + ? (JSON.parse(tableFilters) as FilterFormSet) + : INIT_FORMSET; + const filter = getRunIdsFilter(filterFormSet, selection); + moveRunsArgs.filter = JSON.stringify(filter); + } else { + moveRunsArgs.runIds = runIds; + } + + const results = await moveRuns(moveRunsArgs); await onSubmit?.(results, projId); form.resetFields(); } catch (e) { handleError(e, { publicSubject: 'Unable to move runs' }); } }, [ - flatRuns, form, - onSubmit, - openToast, - sourceProjectId, - sourceWorkspaceId, destinationWorkspaceId, + sourceWorkspaceId, + sourceProjectId, + openToast, + tableFilters, + onSubmit, + selection, + runIds, ]); return ( @@ -127,9 +151,9 @@ const FlatRunMoveModalComponent: React.FC = ({ form: idPrefix + FORM_ID, handleError, handler: handleSubmit, - text: `Move ${pluralizer(flatRuns.length, 'Run')}`, + text: `Move ${pluralizer(selectionSize, 'Run')}`, }} - title={`Move ${pluralizer(flatRuns.length, 'Run')}`}> + title={`Move ${pluralizer(selectionSize, 'Run')}`}>
= ({ flatRun.id), type: 'ONLY_IN' }} + selection={selection ?? { selections: runIds, type: 'ONLY_IN' }} /> ); diff --git a/webui/react/src/pages/FlatRuns/FlatRuns.tsx b/webui/react/src/pages/FlatRuns/FlatRuns.tsx index 361ee1c8cec..e9fbc37c353 100644 --- a/webui/react/src/pages/FlatRuns/FlatRuns.tsx +++ b/webui/react/src/pages/FlatRuns/FlatRuns.tsx @@ -1,4 +1,3 @@ -import { CompactSelection, GridSelection } from '@glideapps/glide-data-grid'; import { isLeft } from 'fp-ts/lib/Either'; import Button from 'hew/Button'; import Column from 'hew/Column'; @@ -14,15 +13,7 @@ import { MULTISELECT, } from 'hew/DataGrid/columns'; import { ContextMenuCompleteHandlerProps } from 'hew/DataGrid/contextMenu'; -import DataGrid, { - DataGridHandle, - HandleSelectionChangeType, - RangelessSelectionType, - SelectionType, - Sort, - validSort, - ValidSort, -} from 'hew/DataGrid/DataGrid'; +import DataGrid, { DataGridHandle, Sort, validSort, ValidSort } from 'hew/DataGrid/DataGrid'; import { MenuItem } from 'hew/Dropdown'; import Icon from 'hew/Icon'; import Link from 'hew/Link'; @@ -38,10 +29,15 @@ import { v4 as uuidv4 } from 'uuid'; import ColumnPickerMenu from 'components/ColumnPickerMenu'; import ComparisonView from 'components/ComparisonView'; import { Error } from 'components/exceptions'; -import { FilterFormStore, ROOT_ID } from 'components/FilterForm/components/FilterFormStore'; +import { + FilterFormStore, + INIT_FORMSET, + ROOT_ID, +} from 'components/FilterForm/components/FilterFormStore'; import { AvailableOperators, FilterFormSet, + FilterFormSetWithoutId, FormField, FormGroup, FormKind, @@ -67,6 +63,7 @@ import useMobile from 'hooks/useMobile'; import usePolling from 'hooks/usePolling'; import useResize from 'hooks/useResize'; import useScrollbarWidth from 'hooks/useScrollbarWidth'; +import useSelection from 'hooks/useSelection'; import { useSettings } from 'hooks/useSettings'; import useTypedParams from 'hooks/useTypedParams'; import FlatRunActionButton from 'pages/FlatRuns/FlatRunActionButton'; @@ -75,15 +72,9 @@ import { getProjectColumns, getProjectNumericMetricsRange, searchRuns } from 'se import { V1ColumnType, V1LocationType, V1TableType } from 'services/api-ts-sdk'; import userStore from 'stores/users'; import userSettings from 'stores/userSettings'; -import { - DetailedUser, - FlatRun, - FlatRunAction, - ProjectColumn, - RunState, - SelectionType as SelectionState, -} from 'types'; +import { DetailedUser, FlatRun, FlatRunAction, ProjectColumn, RunState } from 'types'; import handleError from 'utils/error'; +import { combine } from 'utils/filterFormSet'; import { eagerSubscribe } from 'utils/observable'; import { pluralizer } from 'utils/string'; @@ -192,7 +183,7 @@ const FlatRuns: React.FC = ({ projectId, workspaceId, searchId }) => { _: () => [], Loaded: (formset: FilterFormSet) => formset.filterGroup.children, }); - const filtersString = useObservable(formStore.asJsonString); + const filtersString = useObservable(formStore.asJsonString) || JSON.stringify(INIT_FORMSET); const [total, setTotal] = useState>(NotLoaded); const isMobile = useMobile(); const [isLoading, setIsLoading] = useState(true); @@ -203,6 +194,19 @@ const FlatRuns: React.FC = ({ projectId, workspaceId, searchId }) => { const { openToast } = useToast(); const { width: containerWidth } = useResize(contentRef); + const { + selectionSize, + dataGridSelection, + handleSelectionChange, + loadedSelectedRecords: loadedSelectedRuns, + isRangeSelected, + } = useSelection({ + records: runs, + selection: settings.selection, + total, + updateSettings, + }); + const { ui: { theme: appTheme }, isDarkMode, @@ -248,10 +252,6 @@ const FlatRuns: React.FC = ({ projectId, workspaceId, searchId }) => { return new Set([...BANNED_SORT_COLUMNS, ...arrayTypeColumns]); }, [arrayTypeColumns]); - const selectedRunIdSet = useMemo(() => { - return new Set(settings.selection.type === 'ONLY_IN' ? settings.selection.selections : []); - }, [settings.selection]); - const columnsIfLoaded = useMemo( () => (isLoadingSettings ? [] : settings.columns), [isLoadingSettings, settings.columns], @@ -263,41 +263,18 @@ const FlatRuns: React.FC = ({ projectId, workspaceId, searchId }) => { ); }, [isMobile, settings.compare, settings.pinnedColumnsCount]); - const loadedSelectedRunIds = useMemo(() => { - const selectedMap = new Map(); - const selectedArray: FlatRun[] = []; - if (isLoadingSettings) { - return selectedMap; - } + const loadedRunIdMap = useMemo(() => { + const runMap = new Map(); runs.forEach((r, index) => { Loadable.forEach(r, (run) => { - if (selectedRunIdSet.has(run.id)) { - selectedMap.set(run.id, { index, run }); - selectedArray.push(run); - } + runMap.set(run.id, { index, run }); }); }); - return selectedMap; - }, [isLoadingSettings, runs, selectedRunIdSet]); + return runMap; + }, [runs]); - const selection = useMemo(() => { - let rows = CompactSelection.empty(); - loadedSelectedRunIds.forEach((info) => { - rows = rows.add(info.index); - }); - return { - columns: CompactSelection.empty(), - rows, - }; - }, [loadedSelectedRunIds]); - - const selectedRuns: FlatRun[] = useMemo(() => { - const selected = runs.flatMap((run) => { - return run.isLoaded && selectedRunIdSet.has(run.data.id) ? [run.data] : []; - }); - return selected; - }, [runs, selectedRunIdSet]); + const colorMap = useGlasbey([...loadedRunIdMap.keys()]); const handleIsOpenFilterChange = useCallback((newOpen: boolean) => { setIsOpenFilter(newOpen); @@ -306,8 +283,6 @@ const FlatRuns: React.FC = ({ projectId, workspaceId, searchId }) => { } }, []); - const colorMap = useGlasbey([...loadedSelectedRunIds.keys()]); - const handleToggleComparisonView = useCallback(() => { updateSettings({ compare: !settings.compare }); }, [settings.compare, updateSettings]); @@ -336,7 +311,7 @@ const FlatRuns: React.FC = ({ projectId, workspaceId, searchId }) => { ) .map((columnName) => { if (columnName === MULTISELECT) { - return defaultSelectionColumn(selection.rows, false); + return defaultSelectionColumn(dataGridSelection.rows, false); } if (!Loadable.isLoaded(projectColumnsMap)) { @@ -479,7 +454,7 @@ const FlatRuns: React.FC = ({ projectId, workspaceId, searchId }) => { isDarkMode, projectColumns, projectHeatmap, - selection.rows, + dataGridSelection.rows, settings.columnWidths, settings.compare, settings.heatmapOn, @@ -538,31 +513,30 @@ const FlatRuns: React.FC = ({ projectId, workspaceId, searchId }) => { setRuns(INITIAL_LOADING_RUNS); }, [setPage]); + const filterFormSetString = useMemo(() => { + const filter = JSON.parse(filtersString) as FilterFormSetWithoutId; + if (searchId) { + // only display trials for search + const searchFilter = { + columnName: 'experimentId', + kind: 'field' as const, + location: V1LocationType.RUN, + operator: Operator.Eq, + type: V1ColumnType.NUMBER, + value: searchId, + }; + filter.filterGroup = combine(filter.filterGroup, 'and', searchFilter); + } + return JSON.stringify(filter); + }, [filtersString, searchId]); + const fetchRuns = useCallback(async (): Promise => { if (isLoadingSettings || Loadable.isNotLoaded(loadableFormset)) return; try { - const filters = JSON.parse(filtersString); - if (searchId) { - // only display trials for search - const existingFilterGroup = { ...filters.filterGroup }; - const searchFilter = { - columnName: 'experimentId', - kind: 'field', - location: 'LOCATION_TYPE_RUN', - operator: '=', - type: 'COLUMN_TYPE_NUMBER', - value: searchId, - }; - filters.filterGroup = { - children: [existingFilterGroup, searchFilter], - conjunction: 'and', - kind: 'group', - }; - } const offset = page * settings.pageLimit; const response = await searchRuns( { - filter: JSON.stringify(filters), + filter: filterFormSetString, limit: settings.pageLimit, offset, projectId: projectId, @@ -586,16 +560,15 @@ const FlatRuns: React.FC = ({ projectId, workspaceId, searchId }) => { setIsLoading(false); } }, [ - canceler.signal, - filtersString, isLoadingSettings, loadableFormset, page, - projectId, - resetPagination, settings.pageLimit, + filterFormSetString, + projectId, sortString, - searchId, + canceler.signal, + resetPagination, ]); const { stopPolling } = usePolling(fetchRuns, { rerunOnNewFn: true }); @@ -703,70 +676,6 @@ const FlatRuns: React.FC = ({ projectId, workspaceId, searchId }) => { [settings.columnWidths, updateColumnWidths], ); - const rowRangeToIds = useCallback( - (range: [number, number]) => { - const slice = runs.slice(range[0], range[1]); - return Loadable.filterNotLoaded(slice).map((run) => run.id); - }, - [runs], - ); - - const handleSelectionChange: HandleSelectionChangeType = useCallback( - (selectionType: SelectionType | RangelessSelectionType, range?: [number, number]) => { - let newSettings: SelectionState = { ...settings.selection }; - - switch (selectionType) { - case 'add': - if (!range) return; - if (newSettings.type === 'ALL_EXCEPT') { - const excludedSet = new Set(newSettings.exclusions); - rowRangeToIds(range).forEach((id) => excludedSet.delete(id)); - newSettings.exclusions = Array.from(excludedSet); - } else { - const includedSet = new Set(newSettings.selections); - rowRangeToIds(range).forEach((id) => includedSet.add(id)); - newSettings.selections = Array.from(includedSet); - } - - break; - case 'add-all': - newSettings = { - exclusions: [], - type: 'ALL_EXCEPT' as const, - }; - - break; - case 'remove': - if (!range) return; - if (newSettings.type === 'ALL_EXCEPT') { - const excludedSet = new Set(newSettings.exclusions); - rowRangeToIds(range).forEach((id) => excludedSet.add(id)); - newSettings.exclusions = Array.from(excludedSet); - } else { - const includedSet = new Set(newSettings.selections); - rowRangeToIds(range).forEach((id) => includedSet.delete(id)); - newSettings.selections = Array.from(includedSet); - } - - break; - case 'remove-all': - newSettings = DEFAULT_SELECTION; - - break; - case 'set': - if (!range) return; - newSettings = { - ...DEFAULT_SELECTION, - selections: Array.from(rowRangeToIds(range)), - }; - - break; - } - updateSettings({ selection: newSettings }); - }, - [rowRangeToIds, settings.selection, updateSettings], - ); - const onActionComplete = useCallback(async () => { handleSelectionChange('remove-all'); await fetchRuns(); @@ -870,36 +779,31 @@ const FlatRuns: React.FC = ({ projectId, workspaceId, searchId }) => { [updateSettings], ); + const handleActualSelectAll = useCallback(() => { + handleSelectionChange?.('add-all'); + }, [handleSelectionChange]); + + const handleClearSelect = useCallback(() => { + handleSelectionChange?.('remove-all'); + }, [handleSelectionChange]); + + const handleHeaderClick = useCallback( + (columnId: string): void => { + if (columnId === MULTISELECT) { + if (isRangeSelected([0, settings.pageLimit])) { + handleSelectionChange?.('remove', [0, settings.pageLimit]); + } else { + handleSelectionChange?.('add', [0, settings.pageLimit]); + } + } + }, + [handleSelectionChange, isRangeSelected, settings.pageLimit], + ); + const getHeaderMenuItems = useCallback( (columnId: string, colIdx: number): MenuItem[] => { if (columnId === MULTISELECT) { - const items: MenuItem[] = [ - settings.selection.type === 'ALL_EXCEPT' || settings.selection.selections.length > 0 - ? { - key: 'select-none', - label: 'Clear selected', - onClick: () => { - handleSelectionChange?.('remove-all'); - }, - } - : null, - ...[5, 10, 25].map((n) => ({ - key: `select-${n}`, - label: `Select first ${n}`, - onClick: () => { - handleSelectionChange?.('set', [0, n]); - dataGridRef.current?.scrollToTop(); - }, - })), - { - key: 'select-all', - label: 'Select all', - onClick: () => { - handleSelectionChange?.('add', [0, settings.pageLimit]); - }, - }, - ]; - return items; + return []; } const column = Loadable.getOrElse([], projectColumns).find((c) => c.column === columnId); @@ -1051,12 +955,9 @@ const FlatRuns: React.FC = ({ projectId, workspaceId, searchId }) => { bannedSortColumns, projectColumns, settings.pinnedColumnsCount, - settings.selection, - settings.pageLimit, settings.heatmapOn, settings.heatmapSkipped, isMobile, - handleSelectionChange, columnsIfLoaded, handleColumnsOrderChange, rootFilterChildren, @@ -1118,17 +1019,23 @@ const FlatRuns: React.FC = ({ projectId, workspaceId, searchId }) => { onRowHeightChange={onRowHeightChange} /> @@ -1177,6 +1084,8 @@ const FlatRuns: React.FC = ({ projectId, workspaceId, searchId }) => { open={settings.compare} projectId={projectId} runSelection={settings.selection} + searchId={searchId} + tableFilters={filtersString} onWidthChange={handleCompareWidthChange}> columns={columns} @@ -1208,12 +1117,13 @@ const FlatRuns: React.FC = ({ projectId, workspaceId, searchId }) => { ); }} rowHeight={rowHeightMap[globalSettings.rowHeight as RowHeight]} - selection={selection} + selection={dataGridSelection} sorts={sorts} staticColumns={STATIC_COLUMNS} onColumnResize={handleColumnWidthChange} onColumnsOrderChange={handleColumnsOrderChange} onContextMenuComplete={handleContextMenuComplete} + onHeaderClicked={handleHeaderClick} onPinnedColumnsCountChange={handlePinnedColumnsCountChange} onSelectionChange={handleSelectionChange} /> diff --git a/webui/react/src/services/api.ts b/webui/react/src/services/api.ts index 02a4134d631..a4e57fc0aed 100644 --- a/webui/react/src/services/api.ts +++ b/webui/react/src/services/api.ts @@ -7,7 +7,7 @@ import { DeterminedInfo, Telemetry } from 'stores/determinedInfo'; import { EmptyParams, RawJson, SingleEntityParams } from 'types'; import * as Type from 'types'; import { generateDetApi } from 'utils/service'; -import { tensorBoardMatchesSource } from 'utils/task'; +import { tensorBoardMatchesSource, tensorBoardSearchesMatchesSource } from 'utils/task'; /* Authentication */ @@ -493,6 +493,75 @@ export const changeExperimentLogRetention = generateDetApi< Type.BulkActionResult >(Config.changeExperimentLogRetention); +/* Searches */ + +export const archiveSearches = generateDetApi< + Api.V1ArchiveSearchesRequest, + Api.V1ArchiveSearchesResponse, + Type.BulkActionResult +>(Config.archiveSearches); + +export const deleteSearches = generateDetApi< + Api.V1DeleteSearchesRequest, + Api.V1DeleteSearchesResponse, + Type.BulkActionResult +>(Config.deleteSearches); + +export const killSearches = generateDetApi< + Api.V1KillSearchesRequest, + Api.V1KillSearchesResponse, + Type.BulkActionResult +>(Config.killSearches); + +export const moveSearches = generateDetApi< + Api.V1MoveSearchesRequest, + Api.V1MoveSearchesResponse, + Type.BulkActionResult +>(Config.moveSearches); + +export const unarchiveSearches = generateDetApi< + Api.V1UnarchiveSearchesRequest, + Api.V1UnarchiveSearchesResponse, + void +>(Config.unarchiveSearches); + +export const pauseSearches = generateDetApi< + Api.V1ResumeSearchesRequest, + Api.V1ResumeSearchesResponse, + Type.BulkActionResult +>(Config.pauseSearches); + +export const resumeSearches = generateDetApi< + Api.V1ResumeSearchesRequest, + Api.V1ResumeSearchesResponse, + Type.BulkActionResult +>(Config.resumeSearches); + +export const cancelSearches = generateDetApi< + Api.V1ResumeSearchesRequest, + Api.V1ResumeSearchesResponse, + Type.BulkActionResult +>(Config.resumeSearches); + +export const launchTensorBoardSearches = generateDetApi< + Service.LaunchTensorBoardSearchesParams, + Api.V1LaunchTensorboardSearchesResponse, + Type.CommandResponse +>(Config.launchTensorBoardSearches); + +export const openOrCreateTensorBoardSearches = async ( + params: Service.LaunchTensorBoardSearchesParams, +): Promise => { + const tensorboards = await getTensorBoards({}); + const match = tensorboards.find( + (tensorboard) => + !terminalCommandStates.has(tensorboard.state) && + tensorBoardSearchesMatchesSource(tensorboard, params), + ); + if (match) return { command: match, warnings: [V1LaunchWarning.CURRENTSLOTSEXCEEDED] }; + return launchTensorBoardSearches(params); +}; + /* Tasks */ export const getTask = generateDetApi< diff --git a/webui/react/src/services/apiConfig.ts b/webui/react/src/services/apiConfig.ts index b6c2b0ed1b8..2472eb6e06b 100644 --- a/webui/react/src/services/apiConfig.ts +++ b/webui/react/src/services/apiConfig.ts @@ -1153,6 +1153,104 @@ export const getTrialWorkloads: DetApi< ), }; +/* Searches */ + +export const archiveSearches: DetApi< + Api.V1ArchiveSearchesRequest, + Api.V1ArchiveSearchesResponse, + Type.BulkActionResult +> = { + name: 'archiveSearches', + postProcess: (response) => decoder.mapV1ActionResults(response.results), + request: (params, options) => detApi.Internal.archiveSearches(params, options), +}; + +export const deleteSearches: DetApi< + Api.V1DeleteSearchesRequest, + Api.V1DeleteSearchesResponse, + Type.BulkActionResult +> = { + name: 'deleteSearches', + postProcess: (response) => decoder.mapV1ActionResults(response.results), + request: (params, options) => detApi.Internal.deleteSearches(params, options), +}; + +export const killSearches: DetApi< + Api.V1KillSearchesRequest, + Api.V1KillSearchesResponse, + Type.BulkActionResult +> = { + name: 'killSearches', + postProcess: (response) => decoder.mapV1ActionResults(response.results), + request: (params, options) => detApi.Internal.killSearches(params, options), +}; + +export const moveSearches: DetApi< + Api.V1MoveSearchesRequest, + Api.V1MoveSearchesResponse, + Type.BulkActionResult +> = { + name: 'moveSearches', + postProcess: (response) => decoder.mapV1ActionResults(response.results), + request: (params, options) => detApi.Internal.moveSearches(params, options), +}; + +export const unarchiveSearches: DetApi< + Api.V1UnarchiveSearchesRequest, + Api.V1UnarchiveSearchesResponse, + Type.BulkActionResult +> = { + name: 'unarchiveSearches', + postProcess: (response) => decoder.mapV1ActionResults(response.results), + request: (params, options) => detApi.Internal.unarchiveSearches(params, options), +}; + +export const pauseSearches: DetApi< + Api.V1PauseSearchesRequest, + Api.V1PauseSearchesResponse, + Type.BulkActionResult +> = { + name: 'pauseSearches', + postProcess: (response) => decoder.mapV1ActionResults(response.results), + request: (params, options) => detApi.Internal.pauseSearches(params, options), +}; + +export const resumeSearches: DetApi< + Api.V1ResumeSearchesRequest, + Api.V1ResumeSearchesResponse, + Type.BulkActionResult +> = { + name: 'resumeSearches', + postProcess: (response) => decoder.mapV1ActionResults(response.results), + request: (params, options) => detApi.Internal.resumeSearches(params, options), +}; + +export const cancelSearches: DetApi< + Api.V1CancelSearchesRequest, + Api.V1CancelSearchesResponse, + Type.BulkActionResult +> = { + name: 'cancelSearches', + postProcess: (response) => decoder.mapV1ActionResults(response.results), + request: (params, options) => detApi.Internal.cancelSearches(params, options), +}; + +export const launchTensorBoardSearches: DetApi< + Service.LaunchTensorBoardSearchesParams, + Api.V1LaunchTensorboardSearchesResponse, + Type.CommandResponse +> = { + name: 'launchTensorBoard', + postProcess: (response) => { + return { + command: decoder.mapV1TensorBoard(response.tensorboard), + warnings: response.warnings || [], + }; + }, + request: (params: Service.LaunchTensorBoardSearchesParams) => + detApi.Internal.launchTensorboardSearches(params), +}; + /* Runs */ export const searchRuns: DetApi< diff --git a/webui/react/src/services/types.ts b/webui/react/src/services/types.ts index 3b41ec11778..01960ebd252 100644 --- a/webui/react/src/services/types.ts +++ b/webui/react/src/services/types.ts @@ -164,6 +164,18 @@ export interface SearchRunsParams extends PaginationParams { sort?: string; } +export interface RunBulkActionParams { + projectId: number; + filter?: string; + runIds?: number[]; +} + +export interface SearchBulkActionParams { + projectId: number; + filter?: string; + searchIds?: number[]; +} + export interface GetTaskParams { taskId: string; } @@ -288,6 +300,12 @@ export interface LaunchTensorBoardParams { filters?: Api.V1BulkExperimentFilters; } +export interface LaunchTensorBoardSearchesParams { + searchIds?: Array; + workspaceId?: number; + filter?: string; +} + export interface LaunchJupyterLabParams { config?: { description?: string; diff --git a/webui/react/src/utils/experiment.ts b/webui/react/src/utils/experiment.ts index 862bf0e0df7..032124abfa6 100644 --- a/webui/react/src/utils/experiment.ts +++ b/webui/react/src/utils/experiment.ts @@ -374,7 +374,7 @@ const idToFilter = (operator: Operator, id: number) => export const getIdsFilter = ( filterFormSet: FilterFormSetWithoutId, selection: SelectionType, -): FilterFormSetWithoutId | undefined => { +): FilterFormSetWithoutId => { const filterGroup: FilterFormSetWithoutId['filterGroup'] = selection.type === 'ALL_EXCEPT' ? { diff --git a/webui/react/src/utils/task.ts b/webui/react/src/utils/task.ts index 0c220bf9e56..f29c4ba6d61 100644 --- a/webui/react/src/utils/task.ts +++ b/webui/react/src/utils/task.ts @@ -1,7 +1,7 @@ import _ from 'lodash'; import { killableCommandStates, killableRunStates, terminalCommandStates } from 'constants/states'; -import { LaunchTensorBoardParams } from 'services/types'; +import { LaunchTensorBoardParams, LaunchTensorBoardSearchesParams } from 'services/types'; import * as Type from 'types'; import { CommandState, RunState, State } from 'types'; @@ -221,6 +221,23 @@ export const tensorBoardMatchesSource = ( return false; }; +// Checks whether tensorboard source matches a given source list. +export const tensorBoardSearchesMatchesSource = ( + tensorBoard: Type.CommandTask, + source: LaunchTensorBoardSearchesParams, +): boolean => { + if (source.searchIds) { + source.searchIds?.sort(); + tensorBoard.misc?.experimentIds?.sort(); + + if (_.isEqual(tensorBoard.misc?.experimentIds, source.searchIds)) { + return true; + } + } + + return false; +}; + const commandStateSortOrder: CommandState[] = [ CommandState.Pulling, CommandState.Starting,