diff --git a/reactgrid/lib/utils/getFillDirection.ts b/reactgrid/lib/utils/getFillDirection.ts index 0db42fad..d7b88957 100644 --- a/reactgrid/lib/utils/getFillDirection.ts +++ b/reactgrid/lib/utils/getFillDirection.ts @@ -20,7 +20,7 @@ export const getFillDirection = ( const currectFocusedCell = store.getCellByIndexes(store.focusedLocation.rowIndex, store.focusedLocation.colIndex); - if (pointerColIdx === -1) return undefined; + if (pointerRowIdx === -1 || pointerColIdx === -1) return undefined; const selectedArea = store.selectedArea; @@ -35,31 +35,41 @@ export const getFillDirection = ( cellArea = EMPTY_AREA; } - const differences: { direction: FillDirection; value: number | null }[] = []; - - differences.push({ direction: "", value: null }); - - differences.push({ - direction: "up", - value: pointerRowIdx < cellArea.startRowIdx ? pointerRowIdx : null, - }); - - differences.push({ - direction: "down", - value: pointerRowIdx >= cellArea.endRowIdx ? pointerRowIdx : null, - }); - - differences.push({ - direction: "left", - value: pointerColIdx < cellArea.startColIdx ? pointerColIdx : null, - }); - - differences.push({ - direction: "right", - value: pointerColIdx >= cellArea.endColIdx ? pointerColIdx : null, - }); - - if (differences.every((diff) => diff.value === null)) return differences[0]; - - return differences.reduce((prev, current) => (current.value !== null ? current : prev)); + const bottomDiff = pointerRowIdx >= cellArea.endRowIdx ? Math.abs(pointerRowIdx + 1 - cellArea.endRowIdx) : 0; + const topDiff = pointerRowIdx <= cellArea.startRowIdx ? Math.abs(pointerRowIdx - cellArea.startRowIdx) : 0; + const rightDiff = pointerColIdx >= cellArea.endColIdx ? Math.abs(pointerColIdx + 1 - cellArea.endColIdx) : 0; + const leftDiff = pointerColIdx <= cellArea.startColIdx ? Math.abs(pointerColIdx - cellArea.startColIdx) : 0; + + if (pointerRowIdx >= cellArea.endRowIdx) { + if (bottomDiff >= rightDiff && bottomDiff >= leftDiff) { + return { + direction: "down", + value: pointerRowIdx, + }; + } + } + if (pointerRowIdx <= cellArea.startRowIdx) { + if (topDiff >= rightDiff && topDiff >= leftDiff) { + return { + direction: "up", + value: pointerRowIdx, + }; + } + } + if (pointerColIdx >= cellArea.endColIdx) { + if (rightDiff > topDiff && rightDiff > bottomDiff) { + return { + direction: "right", + value: pointerColIdx, + }; + } + } + if (pointerColIdx <= cellArea.startColIdx) { + if (leftDiff > topDiff && leftDiff > bottomDiff) { + return { + direction: "left", + value: pointerColIdx, + }; + } + } };