Skip to content
Merged
251 changes: 251 additions & 0 deletions src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,6 +1388,257 @@ def most_permissive_type(datatype: DataType) -> DataType:
return copy.deepcopy(datatype)


def format_year_month_interval_for_display(
cell: str, start_field: int, end_field: int
) -> str:
"""
Format a YearMonthIntervalType string for display in _show_string_spark().

Args:
cell: The string representation of the interval (e.g., "+1-6", "-2-03", "24")
start_field: Start field constant from YearMonthIntervalType (YEAR=0, MONTH=1)
end_field: End field constant from YearMonthIntervalType (YEAR=0, MONTH=1)

Returns:
Formatted interval string (e.g., "INTERVAL '1-6' YEAR TO MONTH", "INTERVAL '24' MONTH")
"""
# Handle different input formats
# Check for compound format (year-month) vs simple number
has_internal_dash = (cell.startswith("+") or cell.startswith("-")) and "-" in cell[
1:
]
Comment thread
sfc-gh-fhe marked this conversation as resolved.

# Default initialization
years = "0"
months = "0"
is_negative = False

if has_internal_dash:
# Format like "+1-03" or "-1-03" or "-1-6" (compound year-month)
is_negative = cell.startswith("-")

# Remove the sign prefix and parse the remaining "year-month" part
remaining = cell[1:] # Remove the "+" or "-" prefix: "1-6"
if "-" in remaining:
parts = remaining.split("-", 1) # Split only on first dash: ["1", "6"]
years = str(int(parts[0]))
months = str(int(parts[1]))

# Format based on start/end field
sign_prefix = "-" if is_negative else ""

if (
start_field == YearMonthIntervalType.YEAR
and end_field == YearMonthIntervalType.MONTH
):
# Full range: YEAR TO MONTH
return f"INTERVAL '{sign_prefix}{years}-{months}' YEAR TO MONTH"
elif (
start_field == YearMonthIntervalType.YEAR
and end_field == YearMonthIntervalType.YEAR
):
# Years only: YEAR
return f"INTERVAL '{sign_prefix}{years}' YEAR"
elif (
start_field == YearMonthIntervalType.MONTH
and end_field == YearMonthIntervalType.MONTH
):
# Months only: MONTH - calculate total months
total_months = int(years) * 12 + int(months)
if is_negative:
total_months = -total_months
return f"INTERVAL '{total_months}' MONTH"


def format_day_time_interval_for_display(
cell: Union[str, datetime.timedelta], start_field: int, end_field: int
) -> str:
"""
Format a DayTimeIntervalType value for display in _show_string_spark().

Args:
cell: Either a datetime.timedelta object or string representation
start_field: Start field constant from DayTimeIntervalType (DAY=0, HOUR=1, MINUTE=2, SECOND=3)
end_field: End field constant from DayTimeIntervalType (DAY=0, HOUR=1, MINUTE=2, SECOND=3)

Returns:
Formatted interval string (e.g., "INTERVAL '01:30:45' HOUR TO SECOND")
"""
if isinstance(cell, datetime.timedelta):
# Heuristic: Use Decimal for extreme values near 64-bit boundary, float for normal values
total_seconds_approx = cell.total_seconds()

# Check if we're approaching values where float precision becomes problematic
# Be conservative: use Decimal for large values to ensure precision
# This corresponds to roughly 3 million years - normal use cases are well below this
if (
abs(total_seconds_approx) > 1e11
): # ~100 gigaseconds, very conservative threshold
# Use Decimal arithmetic for precise conversion to avoid floating-point precision loss
total_seconds_value = (
decimal.Decimal(cell.days) * decimal.Decimal(86400)
+ decimal.Decimal(cell.seconds)
+ decimal.Decimal(cell.microseconds) / decimal.Decimal(1_000_000)
)
else:
# Use fast float path for normal values
total_seconds_value = cell.total_seconds()

interval_str = format_day_time_interval(
total_seconds_value, start_field, end_field
)
elif isinstance(cell, str):
# Raw string that needs to be formatted (e.g., "1 01:01:01.7878")
interval_str = cell

field_names = {
DayTimeIntervalType.DAY: "DAY",
DayTimeIntervalType.HOUR: "HOUR",
DayTimeIntervalType.MINUTE: "MINUTE",
DayTimeIntervalType.SECOND: "SECOND",
}

start_name = field_names.get(start_field, "DAY")
end_name = field_names.get(end_field, "SECOND")

if start_field == end_field:
return f"INTERVAL '{interval_str}' {start_name}"
else:
return f"INTERVAL '{interval_str}' {start_name} TO {end_name}"


def format_day_time_interval(
total_seconds_value: Union[float, decimal.Decimal], start_field: int, end_field: int
) -> str:
"""
Format a DayTimeIntervalType value for display in _show_string_spark().

Args:
total_seconds_value: Total seconds as either float or Decimal (can be negative)
start_field: Start field constant from DayTimeIntervalType (DAY=0, HOUR=1, MINUTE=2, SECOND=3)
end_field: End field constant from DayTimeIntervalType (DAY=0, HOUR=1, MINUTE=2, SECOND=3)

Returns:
Formatted interval string (e.g., "01:30:45", "2 12:30", "05", etc.)
"""
is_negative = total_seconds_value < 0
abs_total_seconds = abs(total_seconds_value)

# Determine if we're working with Decimal for high-precision arithmetic
use_decimal = isinstance(total_seconds_value, decimal.Decimal)

days = int(abs_total_seconds) // 86400
remaining_seconds = abs_total_seconds - (days * 86400)
hours = int(remaining_seconds) // 3600
remaining_after_hours = remaining_seconds - (hours * 3600)
minutes = int(remaining_after_hours) // 60

# Calculate seconds more precisely to avoid floating-point accumulation errors
# Use the original total and subtract the calculated day/hour/minute components
if use_decimal:
total_non_second_time = (
decimal.Decimal(days * 86400)
+ decimal.Decimal(hours * 3600)
+ decimal.Decimal(minutes * 60)
)
else:
total_non_second_time = (days * 86400) + (hours * 3600) + (minutes * 60)
seconds = abs_total_seconds - total_non_second_time

sign = "-" if is_negative else ""

def format_with_leading_zero(value: int) -> str:
"""Format integer with leading zero if < 10, otherwise as-is"""
return f"{value:02d}" if value < 10 else f"{value}"

def format_seconds_with_precision(
seconds_value: Union[float, decimal.Decimal]
) -> str:
"""Format seconds with full precision, preserving trailing zeros for proper padding"""
# Unified formatting logic for both Decimal and float types
if seconds_value == int(seconds_value):
return f"{int(seconds_value):02d}"
else:
# For fractional seconds, ensure proper leading zero padding
integer_part = int(seconds_value)
if integer_part < 10:
# Format with leading zero for the integer part
formatted = f"{seconds_value:.6f}".rstrip("0")
if formatted.endswith("."):
return f"{integer_part:02d}"
# Replace the integer part with zero-padded version
decimal_part = formatted.split(".", 1)[1]
return f"{integer_part:02d}.{decimal_part}"
else:
# For >= 10, use normal formatting
formatted = f"{seconds_value:.6f}".rstrip("0")
if formatted.endswith("."):
return f"{integer_part}"
return formatted

# For single field intervals, extract just that component
if start_field == end_field:
if start_field == DayTimeIntervalType.DAY:
return f"{sign}{days}"
elif start_field == DayTimeIntervalType.HOUR:
total_hours = int(abs_total_seconds) // 3600
return f"{sign}{format_with_leading_zero(total_hours)}"
elif start_field == DayTimeIntervalType.MINUTE:
total_minutes = int(abs_total_seconds) // 60
return f"{sign}{format_with_leading_zero(total_minutes)}"
elif start_field == DayTimeIntervalType.SECOND:
# Handle fractional seconds - use total seconds, not just remainder
if abs_total_seconds == int(abs_total_seconds):
total_secs_int = int(abs_total_seconds)
return f"{sign}{format_with_leading_zero(total_secs_int)}"
else:
# Use unified formatting that handles both float and Decimal
return f"{sign}{format_seconds_with_precision(abs_total_seconds)}"

# For multi-field intervals, format based on start/end fields
if start_field == DayTimeIntervalType.DAY:
hours_str = format_with_leading_zero(hours)
# DAY TO X format: truncate based on end_field
if end_field == DayTimeIntervalType.HOUR:
# DAY TO HOUR: "D HH"
return f"{sign}{days} {hours_str}"
elif end_field == DayTimeIntervalType.MINUTE:
# DAY TO MINUTE: "D HH:MM"
return f"{sign}{days} {hours_str}:{minutes:02d}"
else:
# DAY TO SECOND: "D HH:MM:SS"
if seconds == int(seconds):
return f"{sign}{days} {hours_str}:{minutes:02d}:{int(seconds):02d}"
else:
return f"{sign}{days} {hours_str}:{minutes:02d}:{format_seconds_with_precision(seconds)}"
elif start_field == DayTimeIntervalType.HOUR:
# HOUR TO X format: "HH:MM:SS" (no days)
total_hours = int(abs_total_seconds) // 3600
remaining_after_hours = abs_total_seconds - (total_hours * 3600)
mins = int(remaining_after_hours) // 60
secs = remaining_after_hours - (mins * 60)

if end_field == DayTimeIntervalType.MINUTE:
return f"{sign}{format_with_leading_zero(total_hours)}:{mins:02d}"
else: # TO SECOND
if secs == int(secs):
return f"{sign}{format_with_leading_zero(total_hours)}:{mins:02d}:{int(secs):02d}"
else:
return f"{sign}{format_with_leading_zero(total_hours)}:{mins:02d}:{format_seconds_with_precision(secs)}"
elif start_field == DayTimeIntervalType.MINUTE:
# MINUTE TO X format: "MM:SS" (no days or hours)
total_minutes = int(abs_total_seconds) // 60
remaining_secs = abs_total_seconds - (total_minutes * 60)

minutes_str = format_with_leading_zero(total_minutes)
if remaining_secs == int(remaining_secs):
return f"{sign}{minutes_str}:{int(remaining_secs):02d}"
else:
return (
f"{sign}{minutes_str}:{format_seconds_with_precision(remaining_secs)}"
)


# Type hints
ColumnOrName = Union["snowflake.snowpark.column.Column", str]
ColumnOrLiteralStr = Union["snowflake.snowpark.column.Column", str]
Expand Down
18 changes: 18 additions & 0 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@
ColumnOrName,
ColumnOrSqlExpr,
LiteralType,
format_day_time_interval_for_display,
format_year_month_interval_for_display,
snow_type_to_dtype_str,
type_string_to_type_object,
)
Expand Down Expand Up @@ -206,6 +208,7 @@
from snowflake.snowpark.types import (
ArrayType,
DataType,
DayTimeIntervalType,
MapType,
PandasDataFrameType,
StringType,
Expand All @@ -215,6 +218,7 @@
_FractionalType,
TimestampType,
TimestampTimeZone,
YearMonthIntervalType,
)

# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
Expand Down Expand Up @@ -5115,6 +5119,20 @@ def format_timestamp_spark(dt: datetime.datetime) -> str:
res = "-Infinity"
else:
res = str(cell).replace("e+", "E").replace("e-", "E-")
elif isinstance(cell, str) and isinstance(datatype, YearMonthIntervalType):
start_field = getattr(
datatype, "start_field", YearMonthIntervalType.YEAR
)
end_field = getattr(datatype, "end_field", YearMonthIntervalType.MONTH)
res = format_year_month_interval_for_display(
cell, start_field, end_field
)
elif isinstance(cell, (str, datetime.timedelta)) and isinstance(
datatype, DayTimeIntervalType
):
start_field = getattr(datatype, "start_field", DayTimeIntervalType.DAY)
end_field = getattr(datatype, "end_field", DayTimeIntervalType.SECOND)
res = format_day_time_interval_for_display(cell, start_field, end_field)
else:
res = str(cell)
return res.replace("\n", "\\n")
Expand Down
Loading
Loading