Skip to content

Commit 77e112b

Browse files
authored
Merge pull request #382 from AvaCodeSolutions/feat/380/improve-google-auth-group-enrollment
feat: #380 improve google auth group enrollment
2 parents d02bf5b + b02f11d commit 77e112b

19 files changed

Lines changed: 957 additions & 246 deletions

File tree

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Generated by Django 6.0.4 on 2026-04-27 06:05
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
dependencies = [
8+
("django_email_learning", "0013_contentdelivery_reminder_state"),
9+
]
10+
11+
operations = [
12+
migrations.AddField(
13+
model_name="learner",
14+
name="photo",
15+
field=models.ImageField(blank=True, null=True, upload_to="learner_photos/"),
16+
),
17+
migrations.AlterField(
18+
model_name="jobexecution",
19+
name="job_name",
20+
field=models.CharField(
21+
choices=[
22+
("CHECK_IMAP", "check_imap"),
23+
("DELIVER_CONTENTS", "deliver_contents"),
24+
("SEND_REMINDERS", "send_reminders"),
25+
("DEACTIVATE_ENROLLMENTS", "deactivate_enrollments"),
26+
],
27+
max_length=200,
28+
),
29+
),
30+
]

django_email_learning/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,7 @@ class Learner(models.Model):
539539
organization = models.ForeignKey(Organization, on_delete=models.CASCADE)
540540
email = models.EmailField()
541541
created_at = models.DateTimeField(auto_now_add=True)
542+
photo = models.ImageField(upload_to="learner_photos/", null=True, blank=True)
542543

543544
def save(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
544545
self.email = self.email.lower()

django_email_learning/oauth_integrations/group_enrollment/__init__.py

Whitespace-only changes.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from abc import abstractmethod, ABC
2+
from pydantic import BaseModel
3+
4+
5+
class Group(BaseModel):
6+
id: str
7+
name: str
8+
9+
10+
class User(BaseModel):
11+
email: str
12+
photo_path: str | None = None
13+
14+
def __hash__(self) -> int:
15+
return hash(self.email)
16+
17+
def __eq__(self, other: object) -> bool:
18+
if not isinstance(other, User):
19+
return NotImplemented
20+
return self.email == other.email
21+
22+
23+
class BaseGroupEnrollmentHandler(ABC, BaseModel):
24+
provider_and_purpose: str
25+
course_id: int
26+
state: str | None = None
27+
code: str | None = None
28+
29+
@abstractmethod
30+
def handle_redirect(self) -> str:
31+
"""
32+
Handles the OAuth redirect and returns the access_token
33+
"""
34+
raise NotImplementedError(
35+
"Subclasses must implement the handle_redirect method"
36+
)
37+
38+
@abstractmethod
39+
def get_authorization_url(self, state: str) -> str:
40+
"""
41+
Returns the authorization URL to redirect the user to for OAuth authentication
42+
"""
43+
raise NotImplementedError(
44+
"Subclasses must implement the get_authorization_url method"
45+
)
46+
47+
@abstractmethod
48+
def get_groups(self) -> list[Group] | None:
49+
"""
50+
List the groups that exists in an organization. If the external system does not have a concept of groups, return None.
51+
"""
52+
raise NotImplementedError("Subclasses must implement the get_groups method")
53+
54+
@abstractmethod
55+
def get_users_to_enroll(self, groups: set[str]) -> set[User]:
56+
"""
57+
Enrolls the users in the specified course based on the groups they belong to. If groups is None, enroll all users.
58+
"""
59+
raise NotImplementedError("Subclasses must implement the enroll_user method")
Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
import logging
2+
import base64
3+
4+
from django_email_learning.oauth_integrations.models import Session
5+
from django_email_learning.services.jwt_service import decode_jwt
6+
from .base_group_enrollment_handler import BaseGroupEnrollmentHandler, Group, User
7+
from google_auth_oauthlib.flow import Flow # type: ignore
8+
from django.conf import settings
9+
from django.urls import reverse
10+
from django.utils import timezone
11+
from django.core.files.base import ContentFile
12+
from django.core.files.storage import default_storage
13+
from urllib import error, parse, request as urlrequest
14+
from typing import Literal
15+
import json
16+
17+
18+
DJANGO_EMAIL_LEARNING_SETTINGS: dict = getattr(settings, "DJANGO_EMAIL_LEARNING", {})
19+
20+
21+
class GoogleGroupEnrollmentHandler(BaseGroupEnrollmentHandler):
22+
provider_and_purpose: Literal["google_group_enrollment"] = "google_group_enrollment"
23+
code_verifier: str | None = None
24+
25+
def _build_flow(self) -> Flow:
26+
flow = Flow.from_client_config(
27+
client_config={
28+
"web": {
29+
"client_id": DJANGO_EMAIL_LEARNING_SETTINGS.get(
30+
"GOOGLE_OAUTH_CLIENT_ID"
31+
),
32+
"client_secret": DJANGO_EMAIL_LEARNING_SETTINGS.get(
33+
"GOOGLE_OAUTH_CLIENT_SECRET"
34+
),
35+
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
36+
"token_uri": "https://oauth2.googleapis.com/token",
37+
}
38+
},
39+
scopes=[
40+
"https://www.googleapis.com/auth/admin.directory.user.readonly",
41+
"https://www.googleapis.com/auth/admin.directory.group.readonly",
42+
],
43+
state=self.state,
44+
)
45+
flow.redirect_uri = DJANGO_EMAIL_LEARNING_SETTINGS.get(
46+
"SITE_BASE_URL", "http://localhost:8000"
47+
) + reverse("django_email_learning:oauth_integrations:redirect_view")
48+
return flow
49+
50+
def get_authorization_url(self, state: str) -> str:
51+
self.state = state
52+
flow = self._build_flow()
53+
authorization_url, _ = flow.authorization_url(
54+
access_type="offline",
55+
include_granted_scopes="true",
56+
prompt="consent",
57+
code_challenge_method="S256",
58+
)
59+
self.code_verifier = flow.code_verifier
60+
return authorization_url
61+
62+
def handle_redirect(self) -> str:
63+
if not self.code or not self.state:
64+
raise ValueError(
65+
"Authorization code and state are required to enroll from Google Directory"
66+
)
67+
if not self.code_verifier:
68+
raise ValueError(
69+
"Code verifier is required to enroll from Google Directory"
70+
)
71+
72+
flow = self._build_flow()
73+
flow.code_verifier = self.code_verifier
74+
75+
flow.fetch_token(code=self.code)
76+
credentials = flow.credentials
77+
78+
access_token = credentials.token
79+
80+
if not access_token:
81+
raise ValueError("Unable to retrieve access token from Google OAuth flow")
82+
83+
return access_token
84+
85+
def get_groups(self) -> list[Group] | None:
86+
session = Session.objects.filter(session_id=self.state).first()
87+
if not session:
88+
raise ValueError("Session not found for state: {}".format(self.state))
89+
if not session.access_token:
90+
raise ValueError(
91+
"Access token not found in session for state: {}".format(self.state)
92+
)
93+
94+
access_token = decode_jwt(session.access_token)["access_token"]
95+
url = (
96+
"https://www.googleapis.com/admin/directory/v1/groups?customer=my_customer"
97+
)
98+
req = urlrequest.Request(
99+
url,
100+
headers={
101+
"Authorization": f"Bearer {access_token}",
102+
"Accept": "application/json",
103+
},
104+
method="GET",
105+
)
106+
107+
try:
108+
with urlrequest.urlopen(req) as response:
109+
payload = json.loads(response.read().decode("utf-8"))
110+
except error.HTTPError as e:
111+
raise ValueError(f"Google Directory API request failed: {e}") from e
112+
113+
groups = payload.get("groups", [])
114+
return [Group(id=group["id"], name=group["name"]) for group in groups]
115+
116+
def get_users_to_enroll(self, groups: set[str] | None) -> set[User]:
117+
user_ids: set[str] = set()
118+
users: set[User] = set()
119+
session_id = self.state
120+
session = Session.objects.filter(session_id=session_id).first()
121+
if not session:
122+
raise ValueError(f"Session not found for state: {session_id}")
123+
124+
if not session.access_token:
125+
raise ValueError(
126+
f"Access token not found in session for state: {session_id}"
127+
)
128+
129+
access_token = decode_jwt(session.access_token)["access_token"]
130+
131+
if (groups is not None and "all" in groups) or groups is None:
132+
user_ids.update(self._get_user_id_for_all(access_token))
133+
else:
134+
user_ids.update(self._get_user_id_for_groups(access_token, groups))
135+
136+
for user_id in user_ids:
137+
user = self._get_user(user_id)
138+
if user:
139+
users.add(user)
140+
141+
return users
142+
143+
def _get_user_id_for_groups(self, access_token: str, groups: set[str]) -> set[str]:
144+
user_ids = set()
145+
for group_id in groups:
146+
page_token: str | None = None
147+
url = f"https://www.googleapis.com/admin/directory/v1/groups/{group_id}/members"
148+
while True:
149+
query = {
150+
"maxResults": "500",
151+
}
152+
if page_token:
153+
query["pageToken"] = page_token
154+
155+
url_with_query = f"{url}?{parse.urlencode(query)}"
156+
req = urlrequest.Request(
157+
url_with_query,
158+
headers={
159+
"Authorization": f"Bearer {access_token}",
160+
"Accept": "application/json",
161+
},
162+
method="GET",
163+
)
164+
165+
try:
166+
with urlrequest.urlopen(req) as response:
167+
payload = json.loads(response.read().decode("utf-8"))
168+
except error.HTTPError as e:
169+
raise ValueError(
170+
f"Google Directory API request failed {e.code}: {e.reason}"
171+
)
172+
173+
members = payload.get("members", [])
174+
for member in members:
175+
if (
176+
member.get("type") == "USER"
177+
and member.get("status") == "ACTIVE"
178+
):
179+
user_ids.add(member["id"])
180+
page_token = payload.get("nextPageToken")
181+
if not page_token:
182+
break
183+
184+
return user_ids
185+
186+
def _get_user_id_for_all(self, access_token: str) -> set[str]:
187+
user_ids = set()
188+
page_token: str | None = None
189+
users_endpoint = "https://admin.googleapis.com/admin/directory/v1/users"
190+
while True:
191+
query = {
192+
"customer": "my_customer",
193+
"orderBy": "email",
194+
"maxResults": 500,
195+
}
196+
if page_token:
197+
query["pageToken"] = page_token
198+
url = f"{users_endpoint}?{parse.urlencode(query)}"
199+
200+
req = urlrequest.Request(
201+
url,
202+
headers={
203+
"Authorization": f"Bearer {access_token}",
204+
"Accept": "application/json",
205+
},
206+
method="GET",
207+
)
208+
209+
try:
210+
with urlrequest.urlopen(req) as response:
211+
payload = json.loads(response.read().decode("utf-8"))
212+
except error.HTTPError as e:
213+
raise ValueError(f"Google Directory API request failed: {e}") from e
214+
215+
users = payload.get("users", [])
216+
for user in users:
217+
email = user.get("primaryEmail")
218+
is_archived = user.get("archived", False)
219+
is_suspended = user.get("suspended", False)
220+
if not email or is_archived or is_suspended:
221+
continue
222+
user_ids.add(user.get("id"))
223+
224+
page_token = payload.get("nextPageToken")
225+
if not page_token:
226+
break
227+
228+
return user_ids
229+
230+
def _get_user(self, user_id: str) -> User | None:
231+
user_endpoint = (
232+
f"https://admin.googleapis.com/admin/directory/v1/users/{user_id}"
233+
)
234+
session = Session.objects.get(session_id=self.state)
235+
if not session:
236+
raise ValueError(f"Session not found for state: {self.state}")
237+
if not session.access_token:
238+
raise ValueError(
239+
f"Access token not found in session for state: {self.state}"
240+
)
241+
req = urlrequest.Request(
242+
user_endpoint,
243+
headers={
244+
"Authorization": f"Bearer {decode_jwt(session.access_token)['access_token']}",
245+
"Accept": "application/json",
246+
},
247+
method="GET",
248+
)
249+
with urlrequest.urlopen(req) as response:
250+
payload = json.loads(response.read().decode("utf-8"))
251+
252+
email = payload.get("primaryEmail")
253+
254+
if not email:
255+
return None
256+
257+
photo_endpoint = f"https://www.googleapis.com/admin/directory/v1/users/{user_id}/photos/thumbnail"
258+
photo_req = urlrequest.Request(
259+
photo_endpoint,
260+
headers={
261+
"Authorization": f"Bearer {decode_jwt(session.access_token)['access_token']}",
262+
"Accept": "application/json",
263+
},
264+
method="GET",
265+
)
266+
try:
267+
with urlrequest.urlopen(photo_req) as response:
268+
photo_payload = json.loads(response.read().decode("utf-8"))
269+
data = photo_payload.get("photoData")
270+
mime_type = photo_payload.get("mimeType")
271+
file_name = f"{user_id}_photo.{mime_type.lower().replace('image/', '')}"
272+
273+
if data and mime_type:
274+
date_prefix = timezone.now().strftime("%Y%m%d")
275+
padded_data = data + "=" * (-len(data) % 4)
276+
decoded_photo = base64.urlsafe_b64decode(padded_data)
277+
278+
file_path = default_storage.save(
279+
f"uploads/{date_prefix}/{self.course_id}/{file_name}",
280+
ContentFile(decoded_photo),
281+
)
282+
print(file_path)
283+
return User(email=email, photo_path=file_path)
284+
except error.HTTPError:
285+
logging.warning(
286+
f"Failed to retrieve photo for user {email} with id {user_id}"
287+
)
288+
289+
return User(email=email, photo_path=None)

0 commit comments

Comments
 (0)