diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 5bdd663c62934..22bbb73134314 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -610,12 +610,19 @@ def create(self) -> "SparkSession":
             from pyspark.core.context import SparkContext
 
             with self._lock:
-                # Build SparkConf from options
-                sparkConf = SparkConf()
-                for key, value in self._options.items():
-                    sparkConf.set(key, str(value))
-
-                sc = SparkContext.getOrCreate(sparkConf)
+                instantiated_session = SparkSession._instantiatedSession
+                # Get SparkContext
+                if (
+                    instantiated_session is not None
+                    and instantiated_session._sc._jsc is not None
+                ):
+                    sc = instantiated_session._sc
+                else:
+                    sparkConf = SparkConf()
+                    for key, value in self._options.items():
+                        sparkConf.set(key, value)
+                    # This SparkContext may be an existing one.
+                    sc = SparkContext.getOrCreate(sparkConf)
                 jSparkSessionClass = SparkSession._get_j_spark_session_class(sc._jvm)
                 # Create a new SparkSession in the JVM
                 jSparkSession = jSparkSessionClass.builder().config(self._options).create()
diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py
index 3606056f6793d..fb86deb33a2da 100644
--- a/python/pyspark/sql/tests/test_session.py
+++ b/python/pyspark/sql/tests/test_session.py
@@ -616,6 +616,32 @@ def test_create_sessions_share_spark_context(self):
         finally:
             session2.stop()
 
+    def test_create_does_not_construct_spark_conf_when_session_exists(self):
+        """Ensure SparkConf() is not called when a valid session already exists."""
+        self.session = self._get_builder().create()
+        with unittest.mock.patch("pyspark.sql.session.SparkConf") as mock_spark_conf:
+            session2 = self._get_builder().create()
+            try:
+                mock_spark_conf.assert_not_called()
+                self.assertIs(session2.sparkContext, self.session.sparkContext)
+            finally:
+                session2.stop()
+
+    def test_create_applies_mutable_conf_to_second_session(self):
+        """
+        Ensure that mutable SQL configs passed to create() are applied per-session
+        even when a valid SparkSession already exists.
+        """
+        key = "spark.sql.shuffle.partitions"
+        self.session = self._get_builder().config(key, "5").create()
+        self.assertEqual(self.session.conf.get(key), "5")
+        session2 = self._get_builder().config(key, "7").create()
+        try:
+            self.assertEqual(session2.conf.get(key), "7")
+            self.assertIs(session2.sparkContext, self.session.sparkContext)
+        finally:
+            session2.stop()
+
 
 class SparkSessionProfileTests(unittest.TestCase, PySparkErrorTestUtils):
     def setUp(self):