Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions python/pyspark/sql/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,12 +610,19 @@ def create(self) -> "SparkSession":
from pyspark.core.context import SparkContext

with self._lock:
# Build SparkConf from options
sparkConf = SparkConf()
for key, value in self._options.items():
sparkConf.set(key, str(value))

sc = SparkContext.getOrCreate(sparkConf)
instantiated_session = SparkSession._instantiatedSession
# Get SparkContext
if (
instantiated_session is not None
and instantiated_session._sc._jsc is not None
):
sc = instantiated_session._sc
else:
sparkConf = SparkConf()
for key, value in self._options.items():
sparkConf.set(key, value)
# This SparkContext may be an existing one.
sc = SparkContext.getOrCreate(sparkConf)
jSparkSessionClass = SparkSession._get_j_spark_session_class(sc._jvm)
# Create a new SparkSession in the JVM
jSparkSession = jSparkSessionClass.builder().config(self._options).create()
Expand Down
26 changes: 26 additions & 0 deletions python/pyspark/sql/tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,32 @@ def test_create_sessions_share_spark_context(self):
finally:
session2.stop()

def test_create_does_not_construct_spark_conf_when_session_exists(self):
    """Check that create() does not call SparkConf() once a session exists.

    A first session is created normally. ``SparkConf`` is then patched in
    ``pyspark.sql.session`` and a second create() is issued; the patched
    constructor must stay uncalled and both sessions must expose the very
    same SparkContext object.
    """
    self.session = self._get_builder().create()
    conf_patch = unittest.mock.patch("pyspark.sql.session.SparkConf")
    with conf_patch as conf_ctor_mock:
        second_session = self._get_builder().create()
        try:
            # The second create() should have taken the path that skips SparkConf.
            conf_ctor_mock.assert_not_called()
            first_context = self.session.sparkContext
            self.assertIs(second_session.sparkContext, first_context)
        finally:
            # Always stop the extra session, even when an assertion fails.
            second_session.stop()

def test_create_applies_mutable_conf_to_second_session(self):
    """
    Check that a mutable SQL config given to create() takes effect on the
    newly created session even though another SparkSession already exists,
    and that both sessions still share one SparkContext.
    """
    conf_key = "spark.sql.shuffle.partitions"

    # First session: built with the config set to "5".
    first_builder = self._get_builder().config(conf_key, "5")
    self.session = first_builder.create()
    self.assertEqual(self.session.conf.get(conf_key), "5")

    # Second session: same key, different value.
    second_session = self._get_builder().config(conf_key, "7").create()
    try:
        observed = second_session.conf.get(conf_key)
        self.assertEqual(observed, "7")
        self.assertIs(second_session.sparkContext, self.session.sparkContext)
    finally:
        # Always stop the extra session, even when an assertion fails.
        second_session.stop()


class SparkSessionProfileTests(unittest.TestCase, PySparkErrorTestUtils):
def setUp(self):
Expand Down
Loading