|
20 | 20 | import time |
21 | 21 | import base64 |
22 | 22 | import threading |
| 23 | +import re |
| 24 | +from dataclasses import dataclass |
23 | 25 |
|
24 | 26 | from google.protobuf.struct_pb2 import ListValue |
25 | 27 | from google.protobuf.struct_pb2 import Value |
@@ -196,6 +198,162 @@ def _datetime_to_rfc3339_nanoseconds(value): |
196 | 198 | return "{}.{}Z".format(value.isoformat(sep="T", timespec="seconds"), nanos) |
197 | 199 |
|
198 | 200 |
|
| 201 | +@dataclass |
| 202 | +class Interval: |
| 203 | + """Represents a Spanner INTERVAL type. |
| 204 | + |
| 205 | + An interval is a combination of months, days and nanoseconds. |
| 206 | + Internally, Spanner supports Interval value with the following range of individual fields: |
| 207 | + months: [-120000, 120000] |
| 208 | + days: [-3660000, 3660000] |
| 209 | + nanoseconds: [-316224000000000000000, 316224000000000000000] |
| 210 | + """ |
| 211 | + months: int = 0 |
| 212 | + days: int = 0 |
| 213 | + nanos: int = 0 |
| 214 | + |
| 215 | + def __str__(self) -> str: |
| 216 | + """Returns the ISO8601 duration format string representation.""" |
| 217 | + result = ["P"] |
| 218 | + |
| 219 | + # Handle years and months |
| 220 | + if self.months: |
| 221 | + is_negative = self.months < 0 |
| 222 | + abs_months = abs(self.months) |
| 223 | + years, months = divmod(abs_months, 12) |
| 224 | + if years: |
| 225 | + result.append(f"{'-' if is_negative else ''}{years}Y") |
| 226 | + if months: |
| 227 | + result.append(f"{'-' if is_negative else ''}{months}M") |
| 228 | + |
| 229 | + # Handle days |
| 230 | + if self.days: |
| 231 | + result.append(f"{self.days}D") |
| 232 | + |
| 233 | + # Handle time components |
| 234 | + if self.nanos: |
| 235 | + result.append("T") |
| 236 | + nanos = abs(self.nanos) |
| 237 | + is_negative = self.nanos < 0 |
| 238 | + |
| 239 | + # Convert to hours, minutes, seconds |
| 240 | + nanos_per_hour = 3600000000000 |
| 241 | + hours, nanos = divmod(nanos, nanos_per_hour) |
| 242 | + if hours: |
| 243 | + if is_negative: |
| 244 | + result.append("-") |
| 245 | + result.append(f"{hours}H") |
| 246 | + |
| 247 | + nanos_per_minute = 60000000000 |
| 248 | + minutes, nanos = divmod(nanos, nanos_per_minute) |
| 249 | + if minutes: |
| 250 | + if is_negative: |
| 251 | + result.append("-") |
| 252 | + result.append(f"{minutes}M") |
| 253 | + |
| 254 | + nanos_per_second = 1000000000 |
| 255 | + seconds, nanos_fraction = divmod(nanos, nanos_per_second) |
| 256 | + |
| 257 | + if seconds or nanos_fraction: |
| 258 | + if is_negative: |
| 259 | + result.append("-") |
| 260 | + if seconds: |
| 261 | + result.append(str(seconds)) |
| 262 | + elif nanos_fraction: |
| 263 | + result.append("0") |
| 264 | + |
| 265 | + if nanos_fraction: |
| 266 | + nano_str = f"{nanos_fraction:09d}" |
| 267 | + trimmed = nano_str.rstrip("0") |
| 268 | + if len(trimmed) <= 3: |
| 269 | + while len(trimmed) < 3: |
| 270 | + trimmed += "0" |
| 271 | + elif len(trimmed) <= 6: |
| 272 | + while len(trimmed) < 6: |
| 273 | + trimmed += "0" |
| 274 | + else: |
| 275 | + while len(trimmed) < 9: |
| 276 | + trimmed += "0" |
| 277 | + result.append(f".{trimmed}") |
| 278 | + result.append("S") |
| 279 | + |
| 280 | + if len(result) == 1: |
| 281 | + result.append("0Y") # Special case for zero interval |
| 282 | + |
| 283 | + return "".join(result) |
| 284 | + |
| 285 | + @classmethod |
| 286 | + def from_str(cls, s: str) -> 'Interval': |
| 287 | + """Parse an ISO8601 duration format string into an Interval.""" |
| 288 | + pattern = r'^P(-?\d+Y)?(-?\d+M)?(-?\d+D)?(T(-?\d+H)?(-?\d+M)?(-?((\d+([.,]\d{1,9})?)|([.,]\d{1,9}))S)?)?$' |
| 289 | + match = re.match(pattern, s) |
| 290 | + if not match or len(s) == 1: |
| 291 | + raise ValueError(f"Invalid interval format: {s}") |
| 292 | + |
| 293 | + parts = match.groups() |
| 294 | + if not any(parts[:3]) and not parts[3]: |
| 295 | + raise ValueError(f"Invalid interval format: at least one component (Y/M/D/H/M/S) is required: {s}") |
| 296 | + |
| 297 | + if parts[3] == "T" and not any(parts[4:7]): |
| 298 | + raise ValueError(f"Invalid interval format: time designator 'T' present but no time components specified: {s}") |
| 299 | + |
| 300 | + def parse_num(s: str, suffix: str) -> int: |
| 301 | + if not s: |
| 302 | + return 0 |
| 303 | + return int(s.rstrip(suffix)) |
| 304 | + |
| 305 | + years = parse_num(parts[0], "Y") |
| 306 | + months = parse_num(parts[1], "M") |
| 307 | + total_months = years * 12 + months |
| 308 | + |
| 309 | + days = parse_num(parts[2], "D") |
| 310 | + |
| 311 | + nanos = 0 |
| 312 | + if parts[3]: # Has time component |
| 313 | + # Convert hours to nanoseconds |
| 314 | + hours = parse_num(parts[4], "H") |
| 315 | + nanos += hours * 3600000000000 |
| 316 | + |
| 317 | + # Convert minutes to nanoseconds |
| 318 | + minutes = parse_num(parts[5], "M") |
| 319 | + nanos += minutes * 60000000000 |
| 320 | + |
| 321 | + # Handle seconds and fractional seconds |
| 322 | + if parts[6]: |
| 323 | + seconds = parts[6].rstrip("S") |
| 324 | + if "," in seconds: |
| 325 | + seconds = seconds.replace(",", ".") |
| 326 | + |
| 327 | + if "." in seconds: |
| 328 | + sec_parts = seconds.split(".") |
| 329 | + whole_seconds = sec_parts[0] if sec_parts[0] else "0" |
| 330 | + nanos += int(whole_seconds) * 1000000000 |
| 331 | + frac = sec_parts[1][:9].ljust(9, "0") |
| 332 | + frac_nanos = int(frac) |
| 333 | + if seconds.startswith("-"): |
| 334 | + frac_nanos = -frac_nanos |
| 335 | + nanos += frac_nanos |
| 336 | + else: |
| 337 | + nanos += int(seconds) * 1000000000 |
| 338 | + |
| 339 | + return cls(months=total_months, days=days, nanos=nanos) |
| 340 | + |
| 341 | + |
| 342 | +@dataclass |
| 343 | +class NullInterval: |
| 344 | + """Represents a Spanner INTERVAL that may be NULL.""" |
| 345 | + interval: Interval |
| 346 | + valid: bool = True |
| 347 | + |
| 348 | + def is_null(self) -> bool: |
| 349 | + return not self.valid |
| 350 | + |
| 351 | + def __str__(self) -> str: |
| 352 | + if not self.valid: |
| 353 | + return "NULL" |
| 354 | + return str(self.interval) |
| 355 | + |
| 356 | + |
199 | 357 | def _make_value_pb(value): |
200 | 358 | """Helper for :func:`_make_list_value_pbs`. |
201 | 359 |
|
@@ -251,6 +409,12 @@ def _make_value_pb(value): |
251 | 409 | return Value(null_value="NULL_VALUE") |
252 | 410 | else: |
253 | 411 | return Value(string_value=base64.b64encode(value)) |
| 412 | + if isinstance(value, Interval): |
| 413 | + return Value(string_value=str(value)) |
| 414 | + if isinstance(value, NullInterval): |
| 415 | + if value.is_null(): |
| 416 | + return Value(null_value="NULL_VALUE") |
| 417 | + return Value(string_value=str(value.interval)) |
254 | 418 |
|
255 | 419 | raise ValueError("Unknown type: %s" % (value,)) |
256 | 420 |
|
@@ -367,6 +531,8 @@ def _get_type_decoder(field_type, field_name, column_info=None): |
367 | 531 | for item_field in field_type.struct_type.fields |
368 | 532 | ] |
369 | 533 | return lambda value_pb: _parse_struct(value_pb, element_decoders) |
| 534 | + elif type_code == TypeCode.INTERVAL: |
| 535 | + return _parse_interval |
370 | 536 | else: |
371 | 537 | raise ValueError("Unknown type: %s" % (field_type,)) |
372 | 538 |
|
@@ -473,6 +639,13 @@ def _parse_nullable(value_pb, decoder): |
473 | 639 | return decoder(value_pb) |
474 | 640 |
|
475 | 641 |
|
| 642 | +def _parse_interval(value_pb): |
| 643 | + """Parse a Value protobuf containing an interval.""" |
| 644 | + if hasattr(value_pb, 'string_value'): |
| 645 | + return Interval.from_str(value_pb.string_value) |
| 646 | + return Interval.from_str(value_pb) |
| 647 | + |
| 648 | + |
476 | 649 | class _SessionWrapper(object): |
477 | 650 | """Base class for objects wrapping a session. |
478 | 651 |
|
|
0 commit comments