ampelmaennchen.git

ref: master

server/users.go


  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
package server

import (
	"fmt"

	"apiote.xyz/p/ampelmaennchen/accounts"
	"apiote.xyz/p/ampelmaennchen/db"

	"encoding/json"
	"log"
	"net/http"

	"go.mongodb.org/mongo-driver/v2/bson"
)

// Wire types for the two OIDC provider responses decoded by getUserInfo.
type (
	// oidcConfig holds the subset of the provider's discovery document
	// (/.well-known/openid-configuration) that this file reads: the URL of
	// the userinfo endpoint.
	oidcConfig struct {
		UserinfoEndpoint string `json:"userinfo_endpoint"`
	}

	// userInfo holds the claims read from the provider's userinfo endpoint:
	// the subject identifier (compared against user IDs by the handlers) and
	// the e-mail address.
	userInfo struct {
		Sub   string `json:"sub"`
		Email string `json:"email"`
	}
)

// userInfoError is the error type returned by getUserInfo. Besides the
// underlying cause it carries status: the HTTP status code the calling
// handler should answer the request with.
type userInfoError struct {
	status int
	cause  error
}

// Error returns the cause's message. If no cause was recorded it returns a
// generic message naming the status, so logging such a value cannot panic.
func (e userInfoError) Error() string {
	if e.cause == nil {
		return fmt.Sprintf("userinfo error (HTTP status %d)", e.status)
	}
	return e.cause.Error()
}

// Unwrap exposes the cause so errors.Is and errors.As can inspect it.
func (e userInfoError) Unwrap() error {
	return e.cause
}

// getUserInfo asks the OIDC provider who the bearer of authorization is.
//
// It fetches the provider's discovery document
// (<issuer>/.well-known/openid-configuration). It then sends a GET to the
// userinfo_endpoint advertised there, forwarding authorization unchanged as
// the Authorization header, and decodes the reply into a userInfo.
//
// Every non-nil error is a userInfoError whose status is the HTTP status the
// calling handler should respond with:
//   - 401 when no credentials were supplied;
//   - the provider's own 401/403 when it rejects the credentials;
//   - 500 for any other failure (network error, unexpected status, or
//     malformed or incomplete JSON).
func getUserInfo(authorization string) (userInfo, error) {
	const (
		issuer = "https://oauth-bimba.apiote.xyz"
		// timeout caps each round trip to the provider so a stalled provider
		// cannot hang the calling handler indefinitely. http.Client.Timeout
		// is a time.Duration (nanoseconds), so this is 10 seconds.
		timeout = 10_000_000_000
	)

	// Without credentials the provider can only refuse. Answer 401 directly
	// instead of making two network round trips.
	if authorization == "" {
		return userInfo{}, userInfoError{
			status: http.StatusUnauthorized,
			cause:  fmt.Errorf("no authorization header"),
		}
	}

	client := http.Client{Timeout: timeout}

	configResponse, err := client.Get(issuer + "/.well-known/openid-configuration")
	if err != nil {
		return userInfo{}, userInfoError{
			status: http.StatusInternalServerError,
			cause:  fmt.Errorf("while getting OIDC configuration: %w", err),
		}
	}
	defer configResponse.Body.Close()
	if configResponse.StatusCode != http.StatusOK {
		return userInfo{}, userInfoError{
			status: http.StatusInternalServerError,
			cause:  fmt.Errorf("while getting OIDC configuration: %d from issuer", configResponse.StatusCode),
		}
	}

	config := oidcConfig{}
	if err = json.NewDecoder(configResponse.Body).Decode(&config); err != nil {
		return userInfo{}, userInfoError{
			status: http.StatusInternalServerError,
			cause:  fmt.Errorf("while decoding OIDC configuration: %w", err),
		}
	}
	if config.UserinfoEndpoint == "" {
		return userInfo{}, userInfoError{
			status: http.StatusInternalServerError,
			cause:  fmt.Errorf("OIDC configuration has no userinfo_endpoint"),
		}
	}

	userinfoRequest, err := http.NewRequest(http.MethodGet, config.UserinfoEndpoint, nil)
	if err != nil {
		return userInfo{}, userInfoError{
			status: http.StatusInternalServerError,
			cause:  fmt.Errorf("while creating request to userinfo: %w", err),
		}
	}
	userinfoRequest.Header.Set("Authorization", authorization)

	userinfoResponse, err := client.Do(userinfoRequest)
	if err != nil {
		return userInfo{}, userInfoError{
			status: http.StatusInternalServerError,
			cause:  fmt.Errorf("while performing request to userinfo: %w", err),
		}
	}
	defer userinfoResponse.Body.Close()

	// The provider rejected the caller's credentials; pass its verdict on.
	if userinfoResponse.StatusCode == http.StatusForbidden || userinfoResponse.StatusCode == http.StatusUnauthorized {
		return userInfo{}, userInfoError{
			status: userinfoResponse.StatusCode,
			cause:  fmt.Errorf("%d from userinfo", userinfoResponse.StatusCode),
		}
	}
	// Any other non-OK reply is a provider failure. Its body is not a claim
	// set and must not be decoded as one.
	if userinfoResponse.StatusCode != http.StatusOK {
		return userInfo{}, userInfoError{
			status: http.StatusInternalServerError,
			cause:  fmt.Errorf("%d from userinfo", userinfoResponse.StatusCode),
		}
	}

	info := userInfo{}
	if err = json.NewDecoder(userinfoResponse.Body).Decode(&info); err != nil {
		return userInfo{}, userInfoError{
			status: http.StatusInternalServerError,
			cause:  fmt.Errorf("while decoding userinfo: %w", err),
		}
	}
	// Callers authorize by comparing Sub to user IDs, so an empty subject
	// must never be returned as an authenticated identity.
	if info.Sub == "" {
		return userInfo{}, userInfoError{
			status: http.StatusInternalServerError,
			cause:  fmt.Errorf("userinfo response has no sub claim"),
		}
	}
	return info, nil
}

// getUser writes the BSON-encoded record of userID to w.
//
// The caller is identified from r's Authorization header via getUserInfo. The
// request is allowed when the caller asks for their own record, or when the
// caller's current subscription (db.GetValidSubscription) is on the
// accounts.SEAT_GARAGE plan. Otherwise it is answered with 403.
//
// The record from db.GetUser is completed with the user's subscriptions from
// the database. When the caller is the user themselves, it also gets the
// e-mail address reported by the identity provider.
func getUser(userID string, w http.ResponseWriter, r *http.Request) {
	caller, err := getUserInfo(r.Header.Get("authorization"))
	if err != nil {
		// getUserInfo reports the status to answer with; fall back to 500
		// rather than panicking if some other error type ever shows up.
		status := http.StatusInternalServerError
		if uiErr, ok := err.(userInfoError); ok {
			status = uiErr.status
		}
		w.WriteHeader(status)
		log.Printf("while getting user info: %v", err)
		return
	}
	currentSubscription, err := db.GetValidSubscription(caller.Sub)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		log.Printf("while getting logged-in user subscription from db: %v", err)
		return
	}
	isSelf := caller.Sub == userID
	if !isSelf && currentSubscription.Plan != accounts.SEAT_GARAGE {
		w.WriteHeader(http.StatusForbidden)
		return
	}
	user, err := db.GetUser(userID)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		log.Printf("while getting user from db: %v", err)
		return
	}
	// The identity provider's e-mail belongs to the caller. Copying it onto
	// somebody else's record would report the caller's address as theirs.
	if isSelf {
		user.Email = caller.Email
	}
	subscriptions, err := db.GetSubscriptions(userID)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		log.Printf("while getting user's subscriptions from db: %v", err)
		return
	}
	user.Subscriptions = subscriptions

	marshalledUser, err := bson.Marshal(user)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		log.Printf("while marshalling user: %v", err)
		return
	}
	if _, err = w.Write(marshalledUser); err != nil {
		log.Printf("while writing user response: %v", err)
	}
}

// validateTicket redeems a ticket for userID.
//
// Only the user themselves may redeem (403 otherwise). The ticket is taken
// from the "ticket" form value and parsed with accounts.ParseSubscription
// (400 on failure). It is rejected with 403 if db.IsTicketUnused reports it
// as used. It is rejected with 400 if accounts.CheckSignature fails against
// the key from db.GetSubscriptionKey. On success the ticket is recorded with
// db.ValidateTicket. The handler then returns without writing, which net/http
// answers with an empty 200.
func validateTicket(userID string, w http.ResponseWriter, r *http.Request) {
	caller, err := getUserInfo(r.Header.Get("authorization"))
	if err != nil {
		// getUserInfo reports the status to answer with; fall back to 500
		// rather than panicking if some other error type ever shows up.
		status := http.StatusInternalServerError
		if uiErr, ok := err.(userInfoError); ok {
			status = uiErr.status
		}
		w.WriteHeader(status)
		log.Printf("while getting user info: %v", err)
		return
	}
	if caller.Sub != userID {
		w.WriteHeader(http.StatusForbidden)
		log.Printf("user mismatch")
		return
	}
	// A malformed query or body must be rejected, not read as an empty form.
	if err = r.ParseForm(); err != nil {
		w.WriteHeader(http.StatusBadRequest)
		log.Printf("while parsing form: %v", err)
		return
	}

	ticket, err := accounts.ParseSubscription(r.Form.Get("ticket"))
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		log.Printf("while parsing subscription: %v", err)
		return
	}

	// NOTE(review): this check and db.ValidateTicket below are separate
	// calls. Two concurrent requests with the same ticket could both pass
	// here. Confirm whether db.ValidateTicket enforces single use itself.
	isTicketSlotFree, err := db.IsTicketUnused(ticket)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		log.Printf("while checking free ticket slot: %v", err)
		return
	}
	if !isTicketSlotFree {
		w.WriteHeader(http.StatusForbidden)
		log.Printf("ticket used")
		return
	}

	seed, err := db.GetSubscriptionKey()
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		log.Printf("while getting seed: %v", err)
		return
	}

	ticketSignedOK, err := accounts.CheckSignature(ticket, seed)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		log.Printf("while checking ticket signature: %v", err)
		return
	}
	if !ticketSignedOK {
		w.WriteHeader(http.StatusBadRequest)
		log.Printf("ticket signature incorrect")
		return
	}

	if err = db.ValidateTicket(ticket, userID); err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		log.Printf("while validating ticket: %v", err)
		return
	}
}