forked from picoCTF/picoCTF
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathshell_servers.py
346 lines (271 loc) · 9.25 KB
/
shell_servers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
"""Module dealing with shell server integration."""
import json
import spur
import api
from api import PicoException
def get_server(sid):
    """
    Look up a single shell server by its sid.

    Args:
        sid: the server id to look up

    Returns:
        The matching server document with Mongo's "_id" field projected
        out, or None if no server has this sid
    """
    query = {"sid": sid}
    projection = {"_id": 0}
    return api.db.get_conn().shell_servers.find_one(query, projection)
def get_connection(sid):
    """
    Connect to a shell server via SSH.

    Authenticates with the server's private key file when its "keypath" is
    non-empty, otherwise with its stored password. Unknown host keys are
    accepted automatically and the connect timeout is 2 seconds. A trivial
    `echo connected` is run to verify the connection before returning.

    Args:
        sid: the shell server ID

    Returns:
        spur SshShell connection object

    Raises:
        PicoException: if no shell server has this sid (status 404), or if
            the host cannot be connected to / authenticated against
    """
    server = get_server(sid)
    if server is None:
        # Without this guard, server["keypath"] below raised a bare TypeError.
        raise PicoException("Shell server not found.", status_code=404)

    connection_args = {
        "hostname": server["host"],
        "username": server["username"],
        "port": server["port"],
        "missing_host_key": spur.ssh.MissingHostKey.accept,
        "connect_timeout": 2,
    }
    # Default to key-based auth if a keypath is provided; documents stored
    # without a "keypath" field fall back to password auth instead of
    # raising KeyError.
    keypath = server.get("keypath", "")
    if keypath != "":
        connection_args["private_key_file"] = keypath
    else:
        connection_args["password"] = server["password"]

    try:
        shell = spur.SshShell(**connection_args)
        shell.run(["echo", "connected"])
    except spur.ssh.ConnectionError as e:
        raise PicoException(
            "Cannot connect to {}@{}:{}\n{}".format(
                server["username"], server["host"], server["port"], e
            )
        ) from e
    return shell
def add_server(
    *ignore,
    name,
    host,
    port,
    username,
    password="",
    protocol,
    server_number,
    keypath="",
):
    """
    Add a shell server to the pool of servers.

    If server_number is not given (falsy), the server is assigned the lowest
    positive server_number not already in use. When the existing numbers are
    contiguous (1..N) this is N + 1, the same as "current number of servers
    + 1"; unlike that rule it cannot collide with an existing server after
    one has been removed.

    Kwargs:
        name: display name
        host: hostname
        port: SSH port
        username: SSH username
        password: SSH password (used when keypath is empty)
        protocol: HTTP or HTTPS
        server_number: shard number; auto-assigned when falsy
        keypath: path to an SSH private key file (optional)

    Returns:
        sid of the newly created shell server

    Raises:
        PicoException: a shell server with this server_number already exists
            (status 409)
    """
    db = api.db.get_conn()

    if not server_number:
        used_numbers = {
            doc.get("server_number")
            for doc in db.shell_servers.find({}, {"_id": 0, "server_number": 1})
        }
        # Lowest free number: fills gaps left by removed servers instead of
        # landing on a number that is still taken.
        server_number = 1
        while server_number in used_numbers:
            server_number += 1

    if db.shell_servers.find_one({"server_number": server_number}) is not None:
        raise PicoException(
            "Shell server with this server_number already exists.", status_code=409
        )

    sid = api.common.token()
    db.shell_servers.insert_one(
        {
            "sid": sid,
            "name": name,
            "host": host,
            "port": port,
            "username": username,
            "password": password,
            "keypath": keypath,
            "protocol": protocol,
            "server_number": server_number,
        }
    )
    return sid
def update_server(sid, updates):
    """
    Apply a partial update to an existing shell server.

    Args:
        sid: sid of the shell server to modify
        updates: dict of field names to new values, applied with $set

    Returns:
        The sid that was passed in if a matching server existed,
        otherwise None

    Raises:
        PicoException (status 409) if updates sets a server_number that a
        different server already uses
    """
    db = api.db.get_conn()

    # Reject a server_number already claimed by some *other* server.
    if "server_number" in updates:
        clash = db.shell_servers.find_one(
            {"server_number": updates["server_number"], "sid": {"$ne": sid}}
        )
        if clash:
            raise PicoException(
                "Another shell server with this server_number already exists.",
                status_code=409,
            )

    previous = db.shell_servers.find_one_and_update({"sid": sid}, {"$set": updates})
    return sid if previous else None
def remove_server(sid):
    """
    Delete a shell server from the pool.

    Args:
        sid: sid of the shell server to delete

    Returns:
        The given sid if a server was deleted, or None if no server matched
    """
    deleted = api.db.get_conn().shell_servers.find_one_and_delete({"sid": sid})
    if deleted is not None:
        return sid
    return None
def get_all_servers():
    """Return every shell server document, without "_id", as a list."""
    collection = api.db.get_conn().shell_servers
    cursor = collection.find({}, {"_id": 0})
    return list(cursor)
def get_assigned_server():
    """
    Return the shell server(s) available to the currently logged-in team.

    With sharding enabled, only servers whose server_number matches the
    team's server_number (defaulting to 1) are returned; otherwise every
    server is returned.

    Returns:
        list of shell server dicts (without "_id")

    Raises:
        PicoException: sharding is enabled and no server has the team's
            server_number
    """
    db = api.db.get_conn()
    sharding_enabled = api.config.get_settings()["shell_servers"]["enable_sharding"]

    match = {}
    if sharding_enabled:
        team = api.team.get_team()
        match = {"server_number": team.get("server_number", 1)}

    servers = list(db.shell_servers.find(match, {"_id": 0}))
    if not servers and sharding_enabled:
        # Previously the two sentences were joined with no space
        # ("...down.Please contact an admin.").
        raise PicoException(
            "Your assigned shell server is currently down. "
            "Please contact an admin."
        )
    return servers
def get_problem_status_from_server(sid):
    """
    Check whether every problem instance on a shell server is healthy.

    Runs `sudo /picoCTF-env/bin/shell_manager status --json` over SSH,
    closes the connection when the `with` block exits, and parses the JSON.

    An instance counts as offline when its "service" flag is falsy, or when
    its "connection" flag is falsy and its "port" is not None (a remote
    challenge).

    Args:
        sid: The sid of the server to check

    Returns:
        A tuple of:
        - True if no instance is offline, False otherwise
        - The parsed output of `shell_manager status --json`
    """
    shell = get_connection(sid)
    with shell:
        raw = shell.run(
            ["sudo", "/picoCTF-env/bin/shell_manager", "status", "--json"],
            encoding="utf-8",
        ).output
    status = json.loads(raw)

    offline = False
    for problem in status["problems"]:
        for inst in problem["instances"]:
            service_down = not inst["service"]
            remote_unreachable = not inst["connection"] and inst["port"] is not None
            if service_down or remote_unreachable:
                offline = True

    return (not offline, status)
def get_publish_output(sid):
    """
    Run `shell_manager publish` on a shell server and return its parsed output.

    `shell_manager status` is run first. If it exits non-zero, publish is
    not run and a PicoException carrying the status command's stderr is
    raised. The connection is closed when the `with` block exits.

    Args:
        sid: the shell server ID to run the command on

    Returns:
        the output of `shell_manager publish`, decoded with json.loads

    Raises:
        PicoException: the status check exited non-zero, or the connection
            could not be established
    """
    manager = "/picoCTF-env/bin/shell_manager"
    shell = get_connection(sid)
    with shell:
        check = shell.run(
            ["sudo", manager, "status"], allow_error=True, encoding="utf-8"
        )
        if check.return_code != 0:
            raise PicoException(
                "Not all instances online, check shell_manager.",
                data={"stderr": check.stderr_output},
            )
        published = shell.run(["sudo", manager, "publish"], encoding="utf-8")
    return json.loads(published.output)
def get_assigned_server_number(new_team=True, tid=None):
    """
    Assign a server number based on current team count and configured stepping.

    The team's position is the current team count (new team), or the number
    of teams whose _id sorts before the team's own _id (existing team).

    The "steps" setting lists team-count thresholds: a team whose position is
    below steps[-1] goes to server i + 1 for the first i with
    position < steps[i]. Beyond the last threshold, or when no steps are
    configured, servers are filled "default_stepping" teams at a time.

    If "limit_added_range" is set, the result is capped at the highest
    existing server_number. When there are no shell servers there is nothing
    to cap against, and the uncapped number is returned.

    Args:
        new_team: compute the number for a team that is about to be created
        tid: the team to compute for when new_team is False

    Returns:
        (int) server_number

    Raises:
        PicoException: new_team is False and tid is missing or unknown
    """
    settings = api.config.get_settings()["shell_servers"]
    db = api.db.get_conn()

    # NOTE(review): Collection.count() is removed in PyMongo 4;
    # count_documents() requires PyMongo >= 3.7 -- confirm the pinned
    # version before switching.
    if new_team:
        team_count = db.teams.count()
    else:
        if not tid:
            raise PicoException("tid must be specified.")
        oid = db.teams.find_one({"tid": tid}, {"_id": 1})
        if not oid:
            raise PicoException("Invalid tid.")
        team_count = db.teams.count({"_id": {"$lt": oid["_id"]}})

    assigned_number = 1
    steps = settings["steps"]
    if steps:
        if team_count < steps[-1]:
            # steps[-1] itself satisfies the test, so the loop always breaks.
            for i, step in enumerate(steps):
                if team_count < step:
                    assigned_number = i + 1
                    break
        else:
            assigned_number = (
                1
                + len(steps)
                + (team_count - steps[-1]) // settings["default_stepping"]
            )
    else:
        assigned_number = team_count // settings["default_stepping"] + 1

    if settings["limit_added_range"]:
        highest = list(
            db.shell_servers.find({}, {"server_number": 1})
            .sort("server_number", -1)
            .limit(1)
        )
        # Previously list(...)[0] raised IndexError when no servers existed.
        if highest:
            return min(highest[0]["server_number"], assigned_number)
    return assigned_number
def reassign_teams(include_assigned=False):
    """
    Reassign teams to shell servers.

    For each team, the server number is recomputed with
    get_assigned_server_number. If it differs from the team's current one,
    the team gets the new number and its "instances" mapping is cleared.
    Then api.problem.get_unlocked_pids is called for that team to
    re-assign instances.

    Args:
        include_assigned: if True, process every team returned by
            api.team.get_all_teams(); otherwise only teams that have no
            server_number field yet

    Returns:
        the number of teams processed (not only those whose number changed)
    """
    db = api.db.get_conn()

    if include_assigned:
        teams = api.team.get_all_teams()
    else:
        teams = list(
            db.teams.find({"server_number": {"$exists": False}}, {"_id": 0, "tid": 1})
        )

    for team in teams:
        old_server_number = team.get("server_number")
        server_number = get_assigned_server_number(new_team=False, tid=team["tid"])
        if old_server_number != server_number:
            # update_one replaces Collection.update, which is deprecated in
            # PyMongo 3 and removed in PyMongo 4.
            db.teams.update_one(
                {"tid": team["tid"]},
                {"$set": {"server_number": server_number, "instances": {}}},
            )
            # Re-assign instances
            api.problem.get_unlocked_pids(team["tid"])
    return len(teams)