diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py
index 86f97ef55d534..69e751e6df425 100644
--- a/python/pyspark/pandas/base.py
+++ b/python/pyspark/pandas/base.py
@@ -394,7 +394,11 @@ def __abs__(self: IndexOpsLike) -> IndexOpsLike:
 
     # comparison operators
     def __eq__(self, other: Any) -> SeriesOrIndex:  # type: ignore[override]
-        return self._dtype_op.eq(self, other)
+        # pandas always returns False for all items when compared with a dict or a set.
+        if isinstance(other, (dict, set)):
+            return self != self
+        else:
+            return self._dtype_op.eq(self, other)
 
     def __ne__(self, other: Any) -> SeriesOrIndex:  # type: ignore[override]
         return self._dtype_op.ne(self, other)
diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py
index 79004322e1e61..47a6671823d0d 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -376,11 +376,98 @@ def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         raise TypeError(">= can not be applied to %s." % self.pretty_name)
 
     def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
-        from pyspark.pandas.base import column_op
-
-        _sanitize_list_like(right)
+        if isinstance(right, (list, tuple)):
+            from pyspark.pandas.series import first_series, scol_for
+            from pyspark.pandas.frame import DataFrame
+            from pyspark.pandas.internal import NATURAL_ORDER_COLUMN_NAME, InternalField
+
+            len_right = len(right)
+            if len(left) != len_right:
+                raise ValueError("Lengths must be equal")
+
+            sdf = left._internal.spark_frame
+            structed_scol = F.struct(
+                sdf[NATURAL_ORDER_COLUMN_NAME],
+                *left._internal.index_spark_columns,
+                left.spark.column
+            )
+            # The size of the list is expected to be small.
+            collected_structed_scol = F.collect_list(structed_scol)
+            # Sort the array by NATURAL_ORDER_COLUMN so that we can guarantee the order.
+            collected_structed_scol = F.array_sort(collected_structed_scol)
+            right_values_scol = F.array([F.lit(x) for x in right])  # type: ignore
+            index_scol_names = left._internal.index_spark_column_names
+            scol_name = left._internal.spark_column_name_for(left._internal.column_labels[0])
+            # Compare the values of left and right by using the zip_with function.
+            cond = F.zip_with(
+                collected_structed_scol,
+                right_values_scol,
+                lambda x, y: F.struct(
+                    *[
+                        x[index_scol_name].alias(index_scol_name)
+                        for index_scol_name in index_scol_names
+                    ],
+                    F.when(x[scol_name].isNull() | y.isNull(), False)
+                    .otherwise(
+                        x[scol_name] == y,
+                    )
+                    .alias(scol_name)
+                ),
+            ).alias(scol_name)
+            # 1. `sdf_new` here looks like the below (the first field of each struct is the index):
+            # +----------------------------------------------------------+
+            # |0                                                         |
+            # +----------------------------------------------------------+
+            # |[{0, false}, {1, true}, {2, false}, {3, true}, {4, false}]|
+            # +----------------------------------------------------------+
+            sdf_new = sdf.select(cond)
+            # 2. `sdf_new` after the explode looks like the below:
+            # +----------+
+            # |       col|
+            # +----------+
+            # |{0, false}|
+            # | {1, true}|
+            # |{2, false}|
+            # | {3, true}|
+            # |{4, false}|
+            # +----------+
+            sdf_new = sdf_new.select(F.explode(scol_name))
+            # 3. Here, the final `sdf_new` looks like the below:
+            # +-----------------+-----+
+            # |__index_level_0__|    0|
+            # +-----------------+-----+
+            # |                0|false|
+            # |                1| true|
+            # |                2|false|
+            # |                3| true|
+            # |                4|false|
+            # +-----------------+-----+
+            sdf_new = sdf_new.select("col.*")
+
+            index_spark_columns = [
+                scol_for(sdf_new, index_scol_name) for index_scol_name in index_scol_names
+            ]
+            data_spark_columns = [scol_for(sdf_new, scol_name)]
+
+            internal = left._internal.copy(
+                spark_frame=sdf_new,
+                index_spark_columns=index_spark_columns,
+                data_spark_columns=data_spark_columns,
+                index_fields=[
+                    InternalField.from_struct_field(index_field)
+                    for index_field in sdf_new.select(index_spark_columns).schema.fields
+                ],
+                data_fields=[
+                    InternalField.from_struct_field(
+                        sdf_new.select(data_spark_columns).schema.fields[0]
+                    )
+                ],
+            )
+            return first_series(DataFrame(internal))
+        else:
+            from pyspark.pandas.base import column_op
 
-        return column_op(Column.__eq__)(left, right)
+            return column_op(Column.__eq__)(left, right)
 
     def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         from pyspark.pandas.base import column_op
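For reviewers, here is a minimal standalone sketch of the `collect_list` / `array_sort` / `zip_with` / `explode` pipeline the new `eq` branch builds, written against plain PySpark rather than the pandas-on-Spark internals. The frame `df`, the column names `idx`, `x`, and `eq`, and the literal list are illustrative placeholders, not part of the patch; the functions themselves are the same ones the patch relies on.

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

# Stand-in for the Series' Spark frame: an ordering/index column and a value column.
df = spark.createDataFrame([(0, 1), (1, 2), (2, 3)], ["idx", "x"])

# Pack each row into a struct, collect into a single array, and sort by the
# ordering column so the array order matches the row order.
collected = F.array_sort(F.collect_list(F.struct(df["idx"], df["x"])))
values = F.array(*[F.lit(v) for v in [1, 0, 3]])  # the list-like `right`

# zip_with pairs the i-th row struct with the i-th literal; a null on either
# side compares as False, mirroring pandas NaN semantics.
cond = F.zip_with(
    collected,
    values,
    lambda s, v: F.struct(
        s["idx"].alias("idx"),
        F.when(s["x"].isNull() | v.isNull(), False).otherwise(s["x"] == v).alias("eq"),
    ),
)

# Explode the zipped array back into one row per element.
result = df.select(cond.alias("zipped")).select(F.explode("zipped")).select("col.*")
result.show()
# +---+-----+
# |idx|   eq|
# +---+-----+
# |  0| true|
# |  1|false|
# |  2| true|
# +---+-----+
```

Note that `zip_with` null-pads the shorter array rather than failing, which is why the branch raises `ValueError("Lengths must be equal")` up front instead of relying on Spark to catch a length mismatch.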
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 9e2052544318a..f6defe4b23a5f 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -675,7 +675,7 @@ def rfloordiv(self, other: Any) -> "Series":
     koalas = CachedAccessor("koalas", PandasOnSparkSeriesMethods)
 
     # Comparison Operators
-    def eq(self, other: Any) -> bool:
+    def eq(self, other: Any) -> "Series":
         """
         Compare if the current value is equal to the other.
 
diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
index cd5d8347b12cf..ba7f88efe4b7d 100644
--- a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
+++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
@@ -1845,6 +1845,29 @@ def _test_cov(self, pser1, pser2):
         pscov = psser1.cov(psser2, min_periods=3)
         self.assert_eq(pcov, pscov, almost=True)
 
+    def test_series_eq(self):
+        pser = pd.Series([1, 2, 3, 4, 5, 6], name="x")
+        psser = ps.from_pandas(pser)
+
+        # other = Series
+        pandas_other = pd.Series([np.nan, 1, 3, 4, np.nan, 6], name="x")
+        pandas_on_spark_other = ps.from_pandas(pandas_other)
+        self.assert_eq(pser.eq(pandas_other), psser.eq(pandas_on_spark_other).sort_index())
+        self.assert_eq(pser == pandas_other, (psser == pandas_on_spark_other).sort_index())
+
+        # other = Series with a different index
+        pandas_other = pd.Series(
+            [np.nan, 1, 3, 4, np.nan, 6], index=[10, 20, 30, 40, 50, 60], name="x"
+        )
+        pandas_on_spark_other = ps.from_pandas(pandas_other)
+        self.assert_eq(pser.eq(pandas_other), psser.eq(pandas_on_spark_other).sort_index())
+
+        # other = Index
+        pandas_other = pd.Index([np.nan, 1, 3, 4, np.nan, 6], name="x")
+        pandas_on_spark_other = ps.from_pandas(pandas_other)
+        self.assert_eq(pser.eq(pandas_other), psser.eq(pandas_on_spark_other).sort_index())
+        self.assert_eq(pser == pandas_other, (psser == pandas_on_spark_other).sort_index())
+
 
 class OpsOnDiffFramesDisabledTest(PandasOnSparkTestCase, SQLTestUtils):
     @classmethod
@@ -2039,6 +2062,20 @@ def test_combine_first(self):
         with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
             psdf1.combine_first(psdf2)
 
+    def test_series_eq(self):
+        pser = pd.Series([1, 2, 3, 4, 5, 6], name="x")
+        psser = ps.from_pandas(pser)
+
+        others = (
+            ps.Series([np.nan, 1, 3, 4, np.nan, 6], name="x"),
+            ps.Index([np.nan, 1, 3, 4, np.nan, 6], name="x"),
+        )
+        for other in others:
+            with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
+                psser.eq(other)
+            with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
+                psser == other
+
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.test_ops_on_diff_frames import *  # noqa: F401
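A usage sketch of what these tests exercise (illustrative values, assuming a running Spark session): comparing against a pandas-on-Spark Series or Index backed by a different frame works only when `compute.ops_on_diff_frames` is enabled; otherwise the `ValueError` covered by `OpsOnDiffFramesDisabledTest` is raised.

```python
import numpy as np
import pyspark.pandas as ps

psser = ps.Series([1, 2, 3, 4, 5, 6], name="x")
other = ps.Series([np.nan, 1, 3, 4, np.nan, 6], name="x")

# Without the option, psser.eq(other) raises:
#   ValueError: Cannot combine the series or dataframe ...
with ps.option_context("compute.ops_on_diff_frames", True):
    # Rows align by index; NaN on either side compares as False.
    print(psser.eq(other).sort_index())  # False, False, True, True, False, True
```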
diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py
index aba27fa305f58..0ec8d7182cf2d 100644
--- a/python/pyspark/pandas/tests/test_series.py
+++ b/python/pyspark/pandas/tests/test_series.py
@@ -3071,6 +3071,56 @@ def _test_cov(self, pdf):
         pscov = psdf["s1"].cov(psdf["s2"], min_periods=4)
         self.assert_eq(pcov, pscov, almost=True)
 
+    def test_eq(self):
+        pser = pd.Series([1, 2, 3, 4, 5, 6], name="x")
+        psser = ps.from_pandas(pser)
+
+        # other = Series
+        self.assert_eq(pser.eq(pser), psser.eq(psser))
+        self.assert_eq(pser == pser, psser == psser)
+
+        # other = dict
+        other = {1: None, 2: None, 3: None, 4: None, np.nan: None, 6: None}
+        self.assert_eq(pser.eq(other), psser.eq(other))
+        self.assert_eq(pser == other, psser == other)
+
+        # other = set
+        other = {1, 2, 3, 4, np.nan, 6}
+        self.assert_eq(pser.eq(other), psser.eq(other))
+        self.assert_eq(pser == other, psser == other)
+
+        # other = list
+        other = [np.nan, 1, 3, 4, np.nan, 6]
+        if LooseVersion(pd.__version__) >= LooseVersion("1.2"):
+            self.assert_eq(pser.eq(other), psser.eq(other).sort_index())
+            self.assert_eq(pser == other, (psser == other).sort_index())
+        else:
+            self.assert_eq(pser.eq(other).rename("x"), psser.eq(other).sort_index())
+            self.assert_eq((pser == other).rename("x"), (psser == other).sort_index())
+
+        # other = tuple
+        other = (np.nan, 1, 3, 4, np.nan, 6)
+        if LooseVersion(pd.__version__) >= LooseVersion("1.2"):
+            self.assert_eq(pser.eq(other), psser.eq(other).sort_index())
+            self.assert_eq(pser == other, (psser == other).sort_index())
+        else:
+            self.assert_eq(pser.eq(other).rename("x"), psser.eq(other).sort_index())
+            self.assert_eq((pser == other).rename("x"), (psser == other).sort_index())
+
+        # other = list with a different length
+        other = [np.nan, 1, 3, 4, np.nan]
+        with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
+            psser.eq(other)
+        with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
+            psser == other
+
+        # other = tuple with a different length
+        other = (np.nan, 1, 3, 4, np.nan)
+        with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
+            psser.eq(other)
+        with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
+            psser == other
+
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.test_series import *  # noqa: F401
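Finally, a usage sketch of the behavior `test_eq` pins down (illustrative values, assuming a running Spark session): list-like operands now compare element-wise by position, while dict and set operands compare unequal everywhere, matching pandas.

```python
import pyspark.pandas as ps

psser = ps.Series([1, 2, 3], name="x")

print((psser == [1, 0, 3]).sort_index())  # element-wise: True, False, True
print(psser.eq((1, 0, 3)).sort_index())   # tuples behave the same as lists
print((psser == {1, 2, 3}).sort_index())  # set/dict: all False, as in pandas

# A length mismatch is rejected up front:
#   psser == [1, 0]  ->  ValueError: Lengths must be equal
```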