mas_config/sections/
upstream_oauth2.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::collections::BTreeMap;
8
9use camino::Utf8PathBuf;
10use mas_iana::jose::JsonWebSignatureAlg;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize, de::Error};
13use serde_with::skip_serializing_none;
14use ulid::Ulid;
15use url::Url;
16
17use crate::ConfigurationSection;
18
19/// Upstream OAuth 2.0 providers configuration
20#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
21pub struct UpstreamOAuth2Config {
22    /// List of OAuth 2.0 providers
23    pub providers: Vec<Provider>,
24}
25
26impl UpstreamOAuth2Config {
27    /// Returns true if the configuration is the default one
28    pub(crate) fn is_default(&self) -> bool {
29        self.providers.is_empty()
30    }
31}
32
33impl ConfigurationSection for UpstreamOAuth2Config {
34    const PATH: Option<&'static str> = Some("upstream_oauth2");
35
36    fn validate(
37        &self,
38        figment: &figment::Figment,
39    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
40        for (index, provider) in self.providers.iter().enumerate() {
41            let annotate = |mut error: figment::Error| {
42                error.metadata = figment
43                    .find_metadata(&format!("{root}.providers", root = Self::PATH.unwrap()))
44                    .cloned();
45                error.profile = Some(figment::Profile::Default);
46                error.path = vec![
47                    Self::PATH.unwrap().to_owned(),
48                    "providers".to_owned(),
49                    index.to_string(),
50                ];
51                error
52            };
53
54            if !matches!(provider.discovery_mode, DiscoveryMode::Disabled)
55                && provider.issuer.is_none()
56            {
57                return Err(annotate(figment::Error::custom(
58                    "The `issuer` field is required when discovery is enabled",
59                ))
60                .into());
61            }
62
63            match provider.token_endpoint_auth_method {
64                TokenAuthMethod::None
65                | TokenAuthMethod::PrivateKeyJwt
66                | TokenAuthMethod::SignInWithApple => {
67                    if provider.client_secret.is_some() {
68                        return Err(annotate(figment::Error::custom(
69                            "Unexpected field `client_secret` for the selected authentication method",
70                        )).into());
71                    }
72                }
73                TokenAuthMethod::ClientSecretBasic
74                | TokenAuthMethod::ClientSecretPost
75                | TokenAuthMethod::ClientSecretJwt => {
76                    if provider.client_secret.is_none() {
77                        return Err(annotate(figment::Error::missing_field("client_secret")).into());
78                    }
79                }
80            }
81
82            match provider.token_endpoint_auth_method {
83                TokenAuthMethod::None
84                | TokenAuthMethod::ClientSecretBasic
85                | TokenAuthMethod::ClientSecretPost
86                | TokenAuthMethod::SignInWithApple => {
87                    if provider.token_endpoint_auth_signing_alg.is_some() {
88                        return Err(annotate(figment::Error::custom(
89                            "Unexpected field `token_endpoint_auth_signing_alg` for the selected authentication method",
90                        )).into());
91                    }
92                }
93                TokenAuthMethod::ClientSecretJwt | TokenAuthMethod::PrivateKeyJwt => {
94                    if provider.token_endpoint_auth_signing_alg.is_none() {
95                        return Err(annotate(figment::Error::missing_field(
96                            "token_endpoint_auth_signing_alg",
97                        ))
98                        .into());
99                    }
100                }
101            }
102
103            match provider.token_endpoint_auth_method {
104                TokenAuthMethod::SignInWithApple => {
105                    if provider.sign_in_with_apple.is_none() {
106                        return Err(
107                            annotate(figment::Error::missing_field("sign_in_with_apple")).into(),
108                        );
109                    }
110                }
111
112                _ => {
113                    if provider.sign_in_with_apple.is_some() {
114                        return Err(annotate(figment::Error::custom(
115                            "Unexpected field `sign_in_with_apple` for the selected authentication method",
116                        )).into());
117                    }
118                }
119            }
120        }
121
122        Ok(())
123    }
124}
125
126/// The response mode we ask the provider to use for the callback
127#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
128#[serde(rename_all = "snake_case")]
129pub enum ResponseMode {
130    /// `query`: The provider will send the response as a query string in the
131    /// URL search parameters
132    Query,
133
134    /// `form_post`: The provider will send the response as a POST request with
135    /// the response parameters in the request body
136    ///
137    /// <https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html>
138    FormPost,
139}
140
141/// Authentication methods used against the OAuth 2.0 provider
142#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
143#[serde(rename_all = "snake_case")]
144pub enum TokenAuthMethod {
145    /// `none`: No authentication
146    None,
147
148    /// `client_secret_basic`: `client_id` and `client_secret` used as basic
149    /// authorization credentials
150    ClientSecretBasic,
151
152    /// `client_secret_post`: `client_id` and `client_secret` sent in the
153    /// request body
154    ClientSecretPost,
155
156    /// `client_secret_jwt`: a `client_assertion` sent in the request body and
157    /// signed using the `client_secret`
158    ClientSecretJwt,
159
160    /// `private_key_jwt`: a `client_assertion` sent in the request body and
161    /// signed by an asymmetric key
162    PrivateKeyJwt,
163
164    /// `sign_in_with_apple`: a special method for Signin with Apple
165    SignInWithApple,
166}
167
168/// How to handle a claim
169#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
170#[serde(rename_all = "lowercase")]
171pub enum ImportAction {
172    /// Ignore the claim
173    #[default]
174    Ignore,
175
176    /// Suggest the claim value, but allow the user to change it
177    Suggest,
178
179    /// Force the claim value, but don't fail if it is missing
180    Force,
181
182    /// Force the claim value, and fail if it is missing
183    Require,
184}
185
186impl ImportAction {
187    #[allow(clippy::trivially_copy_pass_by_ref)]
188    const fn is_default(&self) -> bool {
189        matches!(self, ImportAction::Ignore)
190    }
191}
192
193/// What should be done for the subject attribute
194#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
195pub struct SubjectImportPreference {
196    /// The Jinja2 template to use for the subject attribute
197    ///
198    /// If not provided, the default template is `{{ user.sub }}`
199    #[serde(default, skip_serializing_if = "Option::is_none")]
200    pub template: Option<String>,
201}
202
203impl SubjectImportPreference {
204    const fn is_default(&self) -> bool {
205        self.template.is_none()
206    }
207}
208
209/// What should be done for the localpart attribute
210#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
211pub struct LocalpartImportPreference {
212    /// How to handle the attribute
213    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
214    pub action: ImportAction,
215
216    /// The Jinja2 template to use for the localpart attribute
217    ///
218    /// If not provided, the default template is `{{ user.preferred_username }}`
219    #[serde(default, skip_serializing_if = "Option::is_none")]
220    pub template: Option<String>,
221}
222
223impl LocalpartImportPreference {
224    const fn is_default(&self) -> bool {
225        self.action.is_default() && self.template.is_none()
226    }
227}
228
229/// What should be done for the displayname attribute
230#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
231pub struct DisplaynameImportPreference {
232    /// How to handle the attribute
233    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
234    pub action: ImportAction,
235
236    /// The Jinja2 template to use for the displayname attribute
237    ///
238    /// If not provided, the default template is `{{ user.name }}`
239    #[serde(default, skip_serializing_if = "Option::is_none")]
240    pub template: Option<String>,
241}
242
243impl DisplaynameImportPreference {
244    const fn is_default(&self) -> bool {
245        self.action.is_default() && self.template.is_none()
246    }
247}
248
249/// What should be done with the email attribute
250#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
251pub struct EmailImportPreference {
252    /// How to handle the claim
253    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
254    pub action: ImportAction,
255
256    /// The Jinja2 template to use for the email address attribute
257    ///
258    /// If not provided, the default template is `{{ user.email }}`
259    #[serde(default, skip_serializing_if = "Option::is_none")]
260    pub template: Option<String>,
261}
262
263impl EmailImportPreference {
264    const fn is_default(&self) -> bool {
265        self.action.is_default() && self.template.is_none()
266    }
267}
268
269/// What should be done for the account name attribute
270#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
271pub struct AccountNameImportPreference {
272    /// The Jinja2 template to use for the account name. This name is only used
273    /// for display purposes.
274    ///
275    /// If not provided, it will be ignored.
276    #[serde(default, skip_serializing_if = "Option::is_none")]
277    pub template: Option<String>,
278}
279
280impl AccountNameImportPreference {
281    const fn is_default(&self) -> bool {
282        self.template.is_none()
283    }
284}
285
286/// How claims should be imported
287#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
288pub struct ClaimsImports {
289    /// How to determine the subject of the user
290    #[serde(default, skip_serializing_if = "SubjectImportPreference::is_default")]
291    pub subject: SubjectImportPreference,
292
293    /// Import the localpart of the MXID
294    #[serde(default, skip_serializing_if = "LocalpartImportPreference::is_default")]
295    pub localpart: LocalpartImportPreference,
296
297    /// Import the displayname of the user.
298    #[serde(
299        default,
300        skip_serializing_if = "DisplaynameImportPreference::is_default"
301    )]
302    pub displayname: DisplaynameImportPreference,
303
304    /// Import the email address of the user based on the `email` and
305    /// `email_verified` claims
306    #[serde(default, skip_serializing_if = "EmailImportPreference::is_default")]
307    pub email: EmailImportPreference,
308
309    /// Set a human-readable name for the upstream account for display purposes
310    #[serde(
311        default,
312        skip_serializing_if = "AccountNameImportPreference::is_default"
313    )]
314    pub account_name: AccountNameImportPreference,
315}
316
317impl ClaimsImports {
318    const fn is_default(&self) -> bool {
319        self.subject.is_default()
320            && self.localpart.is_default()
321            && self.displayname.is_default()
322            && self.email.is_default()
323    }
324}
325
326/// How to discover the provider's configuration
327#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
328#[serde(rename_all = "snake_case")]
329pub enum DiscoveryMode {
330    /// Use OIDC discovery with strict metadata verification
331    #[default]
332    Oidc,
333
334    /// Use OIDC discovery with relaxed metadata verification
335    Insecure,
336
337    /// Use a static configuration
338    Disabled,
339}
340
341impl DiscoveryMode {
342    #[allow(clippy::trivially_copy_pass_by_ref)]
343    const fn is_default(&self) -> bool {
344        matches!(self, DiscoveryMode::Oidc)
345    }
346}
347
348/// Whether to use proof key for code exchange (PKCE) when requesting and
349/// exchanging the token.
350#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
351#[serde(rename_all = "snake_case")]
352pub enum PkceMethod {
353    /// Use PKCE if the provider supports it
354    ///
355    /// Defaults to no PKCE if provider discovery is disabled
356    #[default]
357    Auto,
358
359    /// Always use PKCE with the S256 challenge method
360    Always,
361
362    /// Never use PKCE
363    Never,
364}
365
366impl PkceMethod {
367    #[allow(clippy::trivially_copy_pass_by_ref)]
368    const fn is_default(&self) -> bool {
369        matches!(self, PkceMethod::Auto)
370    }
371}
372
/// Serde default for `bool` fields that are `true` when omitted from the
/// configuration.
fn default_true() -> bool {
    true
}

/// `skip_serializing_if` predicate paired with [`default_true`]: a value equal
/// to that default is left out of the serialized output.
// serde hands `skip_serializing_if` predicates a reference, hence `&bool`
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_default_true(value: &bool) -> bool {
    *value == default_true()
}
381
382#[allow(clippy::ref_option)]
383fn is_signed_response_alg_default(signed_response_alg: &JsonWebSignatureAlg) -> bool {
384    *signed_response_alg == signed_response_alg_default()
385}
386
387#[allow(clippy::unnecessary_wraps)]
388fn signed_response_alg_default() -> JsonWebSignatureAlg {
389    JsonWebSignatureAlg::Rs256
390}
391
392#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
393pub struct SignInWithApple {
394    /// The private key file used to sign the `id_token`
395    #[serde(skip_serializing_if = "Option::is_none")]
396    #[schemars(with = "Option<String>")]
397    pub private_key_file: Option<Utf8PathBuf>,
398
399    /// The private key used to sign the `id_token`
400    #[serde(skip_serializing_if = "Option::is_none")]
401    pub private_key: Option<String>,
402
403    /// The Team ID of the Apple Developer Portal
404    pub team_id: String,
405
406    /// The key ID of the Apple Developer Portal
407    pub key_id: String,
408}
409
/// Serde default for the requested scope: just `openid`.
fn default_scope() -> String {
    String::from("openid")
}

/// `skip_serializing_if` predicate: `true` when `scope` is exactly the
/// default scope, so it is not written out.
fn is_default_scope(scope: &str) -> bool {
    default_scope() == scope
}
417
418/// What to do when receiving an OIDC Backchannel logout request.
419#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
420#[serde(rename_all = "snake_case")]
421pub enum OnBackchannelLogout {
422    /// Do nothing
423    #[default]
424    DoNothing,
425
426    /// Only log out the MAS 'browser session' started by this OIDC session
427    LogoutBrowserOnly,
428
429    /// Log out all sessions started by this OIDC session, including MAS
430    /// 'browser sessions' and client sessions
431    LogoutAll,
432}
433
434impl OnBackchannelLogout {
435    #[allow(clippy::trivially_copy_pass_by_ref)]
436    const fn is_default(&self) -> bool {
437        matches!(self, OnBackchannelLogout::DoNothing)
438    }
439}
440
441/// Configuration for one upstream OAuth 2 provider.
442#[skip_serializing_none]
443#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
444pub struct Provider {
445    /// Whether this provider is enabled.
446    ///
447    /// Defaults to `true`
448    #[serde(default = "default_true", skip_serializing_if = "is_default_true")]
449    pub enabled: bool,
450
451    /// An internal unique identifier for this provider
452    #[schemars(
453        with = "String",
454        regex(pattern = r"^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$"),
455        description = "A ULID as per https://github.com/ulid/spec"
456    )]
457    pub id: Ulid,
458
459    /// The ID of the provider that was used by Synapse.
460    /// In order to perform a Synapse-to-MAS migration, this must be specified.
461    ///
462    /// ## For providers that used OAuth 2.0 or OpenID Connect in Synapse
463    ///
464    /// ### For `oidc_providers`:
465    /// This should be specified as `oidc-` followed by the ID that was
466    /// configured as `idp_id` in one of the `oidc_providers` in the Synapse
467    /// configuration.
468    /// For example, if Synapse's configuration contained `idp_id: wombat` for
469    /// this provider, then specify `oidc-wombat` here.
470    ///
471    /// ### For `oidc_config` (legacy):
472    /// Specify `oidc` here.
473    #[serde(skip_serializing_if = "Option::is_none")]
474    pub synapse_idp_id: Option<String>,
475
476    /// The OIDC issuer URL
477    ///
478    /// This is required if OIDC discovery is enabled (which is the default)
479    #[serde(skip_serializing_if = "Option::is_none")]
480    pub issuer: Option<String>,
481
482    /// A human-readable name for the provider, that will be shown to users
483    #[serde(skip_serializing_if = "Option::is_none")]
484    pub human_name: Option<String>,
485
486    /// A brand identifier used to customise the UI, e.g. `apple`, `google`,
487    /// `github`, etc.
488    ///
489    /// Values supported by the default template are:
490    ///
491    ///  - `apple`
492    ///  - `google`
493    ///  - `facebook`
494    ///  - `github`
495    ///  - `gitlab`
496    ///  - `twitter`
497    ///  - `discord`
498    #[serde(skip_serializing_if = "Option::is_none")]
499    pub brand_name: Option<String>,
500
501    /// The client ID to use when authenticating with the provider
502    pub client_id: String,
503
504    /// The client secret to use when authenticating with the provider
505    ///
506    /// Used by the `client_secret_basic`, `client_secret_post`, and
507    /// `client_secret_jwt` methods
508    #[serde(skip_serializing_if = "Option::is_none")]
509    pub client_secret: Option<String>,
510
511    /// The method to authenticate the client with the provider
512    pub token_endpoint_auth_method: TokenAuthMethod,
513
514    /// Additional parameters for the `sign_in_with_apple` method
515    #[serde(skip_serializing_if = "Option::is_none")]
516    pub sign_in_with_apple: Option<SignInWithApple>,
517
518    /// The JWS algorithm to use when authenticating the client with the
519    /// provider
520    ///
521    /// Used by the `client_secret_jwt` and `private_key_jwt` methods
522    #[serde(skip_serializing_if = "Option::is_none")]
523    pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
524
525    /// Expected signature for the JWT payload returned by the token
526    /// authentication endpoint.
527    ///
528    /// Defaults to `RS256`.
529    #[serde(
530        default = "signed_response_alg_default",
531        skip_serializing_if = "is_signed_response_alg_default"
532    )]
533    pub id_token_signed_response_alg: JsonWebSignatureAlg,
534
535    /// The scopes to request from the provider
536    ///
537    /// Defaults to `openid`.
538    #[serde(default = "default_scope", skip_serializing_if = "is_default_scope")]
539    pub scope: String,
540
541    /// How to discover the provider's configuration
542    ///
543    /// Defaults to `oidc`, which uses OIDC discovery with strict metadata
544    /// verification
545    #[serde(default, skip_serializing_if = "DiscoveryMode::is_default")]
546    pub discovery_mode: DiscoveryMode,
547
548    /// Whether to use proof key for code exchange (PKCE) when requesting and
549    /// exchanging the token.
550    ///
551    /// Defaults to `auto`, which uses PKCE if the provider supports it.
552    #[serde(default, skip_serializing_if = "PkceMethod::is_default")]
553    pub pkce_method: PkceMethod,
554
555    /// Whether to fetch the user profile from the userinfo endpoint,
556    /// or to rely on the data returned in the `id_token` from the
557    /// `token_endpoint`.
558    ///
559    /// Defaults to `false`.
560    #[serde(default)]
561    pub fetch_userinfo: bool,
562
563    /// Expected signature for the JWT payload returned by the userinfo
564    /// endpoint.
565    ///
566    /// If not specified, the response is expected to be an unsigned JSON
567    /// payload.
568    #[serde(skip_serializing_if = "Option::is_none")]
569    pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
570
571    /// The URL to use for the provider's authorization endpoint
572    ///
573    /// Defaults to the `authorization_endpoint` provided through discovery
574    #[serde(skip_serializing_if = "Option::is_none")]
575    pub authorization_endpoint: Option<Url>,
576
577    /// The URL to use for the provider's userinfo endpoint
578    ///
579    /// Defaults to the `userinfo_endpoint` provided through discovery
580    #[serde(skip_serializing_if = "Option::is_none")]
581    pub userinfo_endpoint: Option<Url>,
582
583    /// The URL to use for the provider's token endpoint
584    ///
585    /// Defaults to the `token_endpoint` provided through discovery
586    #[serde(skip_serializing_if = "Option::is_none")]
587    pub token_endpoint: Option<Url>,
588
589    /// The URL to use for getting the provider's public keys
590    ///
591    /// Defaults to the `jwks_uri` provided through discovery
592    #[serde(skip_serializing_if = "Option::is_none")]
593    pub jwks_uri: Option<Url>,
594
595    /// The response mode we ask the provider to use for the callback
596    #[serde(skip_serializing_if = "Option::is_none")]
597    pub response_mode: Option<ResponseMode>,
598
599    /// How claims should be imported from the `id_token` provided by the
600    /// provider
601    #[serde(default, skip_serializing_if = "ClaimsImports::is_default")]
602    pub claims_imports: ClaimsImports,
603
604    /// Additional parameters to include in the authorization request
605    ///
606    /// Orders of the keys are not preserved.
607    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
608    pub additional_authorization_parameters: BTreeMap<String, String>,
609
610    /// Whether the `login_hint` should be forwarded to the provider in the
611    /// authorization request.
612    ///
613    /// Defaults to `false`.
614    #[serde(default)]
615    pub forward_login_hint: bool,
616
617    /// What to do when receiving an OIDC Backchannel logout request.
618    ///
619    /// Defaults to "do_nothing".
620    #[serde(default, skip_serializing_if = "OnBackchannelLogout::is_default")]
621    pub on_backchannel_logout: OnBackchannelLogout,
622}